| |
| |
| """ |
| Pure Qwen3 autoregressive training/inference for music token generation. |
| |
| This variant keeps: |
| - style |
| - section structure in the prompt |
| - section lyrics/description |
| |
| This variant removes: |
| - frame-level chord conditioning |
| - frame-level structure conditioning |
| - condition encoder / AdaLN injection |
| """ |
|
|
| from __future__ import annotations |
|
|
| import argparse |
| import os |
| from dataclasses import dataclass |
| from datetime import datetime |
| from pathlib import Path |
| from typing import Any |
|
|
| import datasets |
| import numpy as np |
| import torch |
| from torch.nn.utils.rnn import pad_sequence |
| from transformers import AutoConfig, AutoTokenizer, Trainer, TrainingArguments |
| from transformers.models.qwen3.modeling_qwen3 import Qwen3ForCausalLM |
|
|
| from audio_tokens import ( |
| EOA_TOKEN, |
| MASK_AUDIO_TOKEN, |
| SOA_TOKEN, |
| add_audio_special_tokens, |
| audio_id_to_token, |
| ) |
| from dataset import normalize_section_text, SECTION_NAME_MAP, SINGLETON_SECTION_NAMES |
| from inference_full import build_mucodec_decoder, decode_mucodec_codes |
| from runtime_utils import resolve_device, seed_everything |
| from vocab import normalize_structure_label |
|
|
|
|
def parse_args() -> argparse.Namespace:
    """Build the CLI with ``train`` and ``infer`` subcommands and parse argv."""
    parser = argparse.ArgumentParser(
        description="Pure Qwen3 autoregressive training/inference without frame-level conditioning."
    )
    subparsers = parser.add_subparsers(dest="command", required=True)

    add_train_args(
        subparsers.add_parser(
            "train",
            help="Train a plain Qwen3 autoregressive model on section prompts and audio tokens.",
        )
    )
    add_infer_args(
        subparsers.add_parser(
            "infer",
            help="Run section-wise autoregressive inference with a plain Qwen3 checkpoint.",
        )
    )
    return parser.parse_args()
|
|
|
|
def add_train_args(parser: argparse.ArgumentParser) -> None:
    """Register all training options on *parser* (defaults mirror the paper run)."""
    # Data and checkpoint locations.
    parser.add_argument("--dataset_path", type=str, default="muse_mucodec_chord.ds")
    parser.add_argument("--model_path", type=str, default="checkpoints/Qwen3-0.6B",
                        help="Local Qwen3 base checkpoint path.")
    parser.add_argument("--tokenizer_path", type=str, default="checkpoints/Qwen3-0.6B",
                        help="Local tokenizer checkpoint path.")
    parser.add_argument("--num_audio_token", type=int, default=None,
                        help="Audio codebook size. Defaults to checkpoint metadata when available, else 16384.")

    # Model precision / attention backend.
    parser.add_argument("--model_dtype", type=str, default="bfloat16",
                        choices=["float32", "float16", "bfloat16"])
    parser.add_argument("--attn_implementation", type=str, default="sdpa",
                        choices=["eager", "sdpa", "flash_attention_2"])

    # Optimization schedule.
    parser.add_argument("--output_dir", type=str, default="./output_qwen3_plain_ar")
    parser.add_argument("--per_device_train_batch_size", type=int, default=1)
    parser.add_argument("--gradient_accumulation_steps", type=int, default=4)
    parser.add_argument("--learning_rate", type=float, default=1e-4)
    parser.add_argument("--weight_decay", type=float, default=0.01)
    parser.add_argument("--num_train_epochs", type=float, default=20)
    parser.add_argument("--warmup_steps", type=int, default=1000)
    parser.add_argument("--max_grad_norm", type=float, default=5.0)
    parser.add_argument("--logging_steps", type=int, default=10)
    parser.add_argument("--resume_from_checkpoint", type=str, default=None,
                        help="Resume training from a Trainer checkpoint directory such as output_dir/checkpoint-500.")

    # Runtime knobs.
    parser.add_argument("--dataloader_num_workers", type=int, default=12)
    parser.add_argument("--gradient_checkpointing", dest="gradient_checkpointing",
                        action="store_true")
    parser.add_argument("--deepspeed", type=str, default=None,
                        help="Path to DeepSpeed config. Leave unset to disable DeepSpeed.")

    # Experiment tracking.
    parser.add_argument("--report_to", type=str, default="wandb")
    parser.add_argument("--wandb_project", type=str, default="vaultum-qwen3-0p6b")
    parser.add_argument("--wandb_run_name", type=str, default=None)
|
|
|
|
def add_infer_args(parser: argparse.ArgumentParser) -> None:
    """Register all inference options on *parser*."""
    # Checkpoint and data selection.
    parser.add_argument("--model_path", type=str, required=True)
    parser.add_argument("--tokenizer_path", type=str, default=None,
                        help="Tokenizer path. Defaults to --model_path.")
    parser.add_argument("--dataset_path", type=str, default="muse_mucodec_chord.ds")
    parser.add_argument("--split", type=str, default="validation")
    parser.add_argument("--sample_idx", type=int, default=0)
    parser.add_argument("--num_audio_token", type=int, default=None,
                        help="Audio codebook size. Defaults to checkpoint metadata when available, else 16384.")

    # Runtime placement and precision.
    parser.add_argument("--seed", type=int, default=1234)
    parser.add_argument("--device", type=str, default="auto")
    parser.add_argument("--dtype", type=str, default="bfloat16",
                        choices=["float32", "float16", "bfloat16"])
    parser.add_argument("--attn_implementation", type=str, default="sdpa",
                        choices=["eager", "sdpa", "flash_attention_2"])

    # Sampling controls.
    parser.add_argument("--temperature", type=float, default=1.0)
    parser.add_argument("--top_k", type=int, default=50)
    parser.add_argument("--top_p", type=float, default=0.90)
    parser.add_argument("--greedy", action="store_true", default=False)
    # NOTE: --use_cache is store_true with default=True, so the flag cannot
    # disable caching; --no_cache is the effective off switch.
    parser.add_argument("--use_cache", action="store_true", default=True)
    parser.add_argument("--no_cache", action="store_true", default=False)
    parser.add_argument("--max_new_tokens_per_section", type=int, default=2048,
                        help="Upper bound for each section decode before forcing a failure.")

    # Output and MuCodec decoding parameters.
    parser.add_argument("--output_dir", type=str, default="plain_ar_predictions")
    parser.add_argument("--output_prefix", type=str, default="")
    parser.add_argument("--skip_decode", action="store_true", default=False)
    parser.add_argument("--mucodec_device", type=str, default="auto")
    parser.add_argument("--mucodec_layer_num", type=int, default=7)
    parser.add_argument("--mucodec_duration", type=float, default=40.96)
    parser.add_argument("--mucodec_guidance_scale", type=float, default=1.5)
    parser.add_argument("--mucodec_num_steps", type=int, default=20)
    parser.add_argument("--mucodec_sample_rate", type=int, default=48000)
|
|
|
|
def resolve_model_source(model_path: str, resume_from_checkpoint: str | None) -> str:
    """Choose which checkpoint to load: the resume directory wins over --model_path."""
    if not resume_from_checkpoint:
        return model_path
    same_location = os.path.abspath(model_path) == os.path.abspath(resume_from_checkpoint)
    if not same_location:
        # Surface the override so a stale --model_path is not silently ignored.
        print(
            "Ignoring --model_path during resume and loading config/model from: "
            f"{resume_from_checkpoint}"
        )
    return resume_from_checkpoint
|
|
|
|
def get_model_dtype(name: str) -> torch.dtype:
    """Translate a CLI dtype name into the matching torch.dtype (KeyError otherwise)."""
    dtype_by_name: dict[str, torch.dtype] = {
        "float32": torch.float32,
        "float16": torch.float16,
        "bfloat16": torch.bfloat16,
    }
    return dtype_by_name[name]
|
|
|
|
def resolve_num_audio_token(checkpoint_path: str, explicit_value: int | None) -> int:
    """Return the audio codebook size for a run.

    An explicit CLI value always wins; otherwise the checkpoint config's
    ``magel_num_audio_token`` field is consulted, falling back to 16384.
    """
    if explicit_value is not None:
        return int(explicit_value)
    config = AutoConfig.from_pretrained(checkpoint_path, local_files_only=True)
    return int(getattr(config, "magel_num_audio_token", 16384))
|
|
|
|
@dataclass
class PreparedSection:
    """One song section after frame alignment and text normalization."""

    # Normalized lyrics for the section (may be empty).
    text: str
    # Free-form section description from the dataset, whitespace-stripped.
    desc: str
    # First MuCodec frame index of the section (inclusive).
    start_frame: int
    # One-past-last MuCodec frame index of the section (exclusive).
    end_frame: int
    # Structure label as produced by normalize_structure_label.
    structure: str
    # 1-based occurrence count of this structure label within the song.
    index: int
|
|
|
|
class PlainARDataCollator:
    """Pad variable-length training rows into a rectangular batch.

    ``input_ids`` are padded with the tokenizer pad id, ``attention_mask``
    with 0, and ``labels`` with -100 so padded positions are ignored by the
    cross-entropy loss.
    """

    def __init__(self, pad_token_id: int = 0):
        self.pad_token_id = int(pad_token_id)

    def __call__(self, batch: list[dict[str, torch.Tensor]]) -> dict[str, torch.Tensor]:
        def padded(key: str, fill_value: int) -> torch.Tensor:
            # Right-pad every row's tensor for this key to the batch maximum.
            return pad_sequence(
                [row[key] for row in batch],
                batch_first=True,
                padding_value=fill_value,
            )

        return {
            "input_ids": padded("input_ids", self.pad_token_id),
            "attention_mask": padded("attention_mask", 0),
            "labels": padded("labels", -100),
        }
|
|
|
|
class PlainARMusicDataset(torch.utils.data.Dataset):
    """Map-style dataset producing chat-formatted token sequences per song.

    Each sample is rendered as a multi-turn conversation: one user turn per
    song section (style preamble, description, lyrics) followed by an
    assistant turn whose [SOA]...[EOA] span is expanded with that section's
    MuCodec audio code ids. Labels keep only the audio code tokens and the
    audio EOS, so the text prompt is never trained on.
    """

    def __init__(
        self,
        datasets_obj,
        split: str,
        tokenizer_path: str,
        num_audio_token: int = 16384,
        fps: int = 25,
        use_fast: bool = True,
    ):
        """Load one split and build a tokenizer extended with audio tokens.

        Args:
            datasets_obj: Mapping from split name to an indexable dataset.
            split: Which split to expose (e.g. "train").
            tokenizer_path: Local tokenizer checkpoint path.
            num_audio_token: MuCodec audio codebook size.
            fps: Frames per second used to turn section end times into frame
                indices. NOTE(review): assumes 25 fps codec frames — confirm
                against the MuCodec configuration.
            use_fast: Whether to load the fast (Rust) tokenizer.
        """
        self._data = datasets_obj[split]
        self.tokenizer_path = tokenizer_path
        self.num_audio_token = int(num_audio_token)
        self.fps = int(fps)
        self.use_fast = bool(use_fast)

        self.tokenizer = AutoTokenizer.from_pretrained(
            tokenizer_path,
            local_files_only=True,
            use_fast=use_fast,
        )
        # Extends the tokenizer with audio special tokens and per-code tokens.
        add_audio_special_tokens(self.tokenizer, self.num_audio_token)
        self.tokenizer_vocab_size = len(self.tokenizer)

        # Token id of audio code 0: adding this offset maps a raw MuCodec code
        # index to its tokenizer id; subtracting it maps back.
        self.audio_prefix_length = int(
            self.tokenizer.convert_tokens_to_ids(audio_id_to_token(0))
        )
        self.MASK_AUDIO = int(self.tokenizer.convert_tokens_to_ids(MASK_AUDIO_TOKEN))
        self.BOS_AUDIO = int(self.tokenizer.convert_tokens_to_ids(SOA_TOKEN))
        self.EOS_AUDIO = int(self.tokenizer.convert_tokens_to_ids(EOA_TOKEN))
        self.pad_token_id = (
            int(self.tokenizer.pad_token_id)
            if self.tokenizer.pad_token_id is not None
            else 0
        )
        # Empty [SOA][EOA] placeholder; expand_audio_tokens later splices the
        # real audio code ids between each such pair.
        self._assistant_audio_placeholder = f"{SOA_TOKEN}{EOA_TOKEN}"
        self._chat_template_kwargs = {"enable_thinking": False}

    def __len__(self) -> int:
        """Number of songs in the selected split."""
        return len(self._data)

    def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
        """Return ``input_ids``/``attention_mask``/``labels`` for one song."""
        sample = self.raw_sample(idx)
        sections = self.prepare_sections(sample)
        token_ids, attention_mask = self.tokenize_messages(
            self.build_messages(sample, sections),
            sample["mucodec_codes"],
            sections,
        )
        labels = self.build_labels(token_ids)

        return {
            "input_ids": token_ids,
            "attention_mask": attention_mask,
            "labels": labels,
        }

    def raw_sample(self, idx: int) -> dict[str, Any]:
        """Fetch the untouched dataset record at *idx*."""
        return self._data[idx]

    @staticmethod
    def _positions(token_ids: torch.Tensor, target_id: int) -> torch.Tensor:
        """Return the 1-D indices where ``token_ids`` equals ``target_id``."""
        return torch.nonzero(token_ids == target_id, as_tuple=False).squeeze(-1)

    @staticmethod
    def _sorted_sections(sample: dict[str, Any]) -> list[dict[str, Any]]:
        """Normalize the sample's section entries and sort them chronologically.

        Ties on (start, end) are broken by the original ordering so the sort
        is stable and deterministic.
        """
        return sorted(
            (
                {
                    "raw_index": raw_index,
                    "text": str(seg["text"]),
                    "desc": str(seg["desc"]).strip(),
                    "start": float(seg["start"]),
                    "end": float(seg["end"]),
                    "structure": normalize_structure_label(seg["section"]),
                }
                for raw_index, seg in enumerate(sample.get("sections", []))
            ),
            key=lambda seg: (seg["start"], seg["end"], seg["raw_index"]),
        )

    def prepare_sections(self, sample: dict[str, Any]) -> list[PreparedSection]:
        """Convert section times to contiguous, non-overlapping frame ranges.

        Each section starts where the previous one ended, end frames are
        clamped into [prev_end, total_frames], and the final section is
        stretched to cover the whole code sequence — so every frame belongs
        to exactly one section.
        """
        sections: list[PreparedSection] = []
        section_counts: dict[str, int] = {}
        total_frames = len(sample["mucodec_codes"])
        prev_end_idx = 0
        sample_language = sample.get("language")

        for seg in self._sorted_sections(sample):
            structure = seg["structure"]
            # Count occurrences per label to number repeated sections.
            section_counts[structure] = section_counts.get(structure, 0) + 1
            # Clamp so sections never overlap or run past the available codes.
            raw_end_idx = max(
                prev_end_idx,
                min(total_frames, int(np.ceil(seg["end"] * self.fps))),
            )
            sections.append(
                PreparedSection(
                    text=normalize_section_text(
                        seg["text"],
                        structure,
                        language=sample_language,
                    ),
                    desc=seg["desc"],
                    start_frame=prev_end_idx,
                    end_frame=raw_end_idx,
                    structure=structure,
                    index=section_counts[structure],
                )
            )
            prev_end_idx = raw_end_idx

        # Extend the last section to the end of the code sequence so no
        # trailing frames are dropped.
        if sections:
            last = sections[-1]
            sections[-1] = PreparedSection(
                text=last.text,
                desc=last.desc,
                start_frame=last.start_frame,
                end_frame=total_frames,
                structure=last.structure,
                index=last.index,
            )

        return sections

    def format_section_label(self, section: PreparedSection) -> str:
        """Human-readable label; singleton sections omit the occurrence number."""
        label = SECTION_NAME_MAP[section.structure]
        if section.structure in SINGLETON_SECTION_NAMES and section.index == 1:
            return label
        return f"{label} {section.index}"

    def build_section_user_content(
        self,
        sample: dict[str, Any],
        section: PreparedSection,
        is_first_turn: bool,
    ) -> str:
        """Compose the user prompt text for one section.

        The first turn carries a global style preamble (or a generic one when
        the sample has no style); every turn carries the bracketed section
        label plus optional [desc:...] and [lyrics:...] fields.
        """
        parts: list[str] = []
        if is_first_turn:
            style = str(sample.get("style", "")).strip()
            if style:
                parts.append(
                    f"Please generate a song in the following style:{style}\n"
                    "Next, I will tell you the requirements and lyrics for the song "
                    "fragment to be generated, section by section."
                )
            else:
                parts.append(
                    "Please generate the song section by section. "
                    "Next, I will tell you the requirements and lyrics for each fragment."
                )

        section_parts = [f"[{self.format_section_label(section)}]"]
        if section.desc:
            section_parts.append(f"[desc:{section.desc}]")
        if section.text:
            section_parts.append(f"[lyrics:{section.text}]")
        parts.append("".join(section_parts))
        return "\n".join(part for part in parts if part)

    def build_messages(
        self,
        sample: dict[str, Any],
        sections: list[PreparedSection],
    ) -> list[dict[str, str]]:
        """Build the alternating user/assistant chat transcript for training.

        Every assistant turn initially holds the empty [SOA][EOA] placeholder;
        the actual audio code ids are spliced in later by expand_audio_tokens.
        """
        messages: list[dict[str, str]] = []
        for idx, section in enumerate(sections):
            messages.append(
                {
                    "role": "user",
                    "content": self.build_section_user_content(
                        sample=sample,
                        section=section,
                        is_first_turn=(idx == 0),
                    ),
                }
            )
            messages.append(
                {
                    "role": "assistant",
                    "content": self._assistant_audio_placeholder,
                }
            )
        return messages

    def tokenize_messages(
        self,
        messages: list[dict[str, str]],
        full_audio_codes,
        sections: list[PreparedSection],
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Apply the chat template, then splice per-section audio codes in."""
        chat_inputs = self.tokenizer.apply_chat_template(
            messages,
            tokenize=True,
            add_generation_prompt=False,
            return_tensors="pt",
            return_dict=True,
            **self._chat_template_kwargs,
        )
        token_ids = chat_inputs["input_ids"].squeeze(0).to(torch.long)
        attention_mask = chat_inputs["attention_mask"].squeeze(0).to(torch.long)
        return self.expand_audio_tokens(
            token_ids=token_ids,
            attention_mask=attention_mask,
            full_audio_codes=full_audio_codes,
            sections=sections,
        )

    def expand_audio_tokens(
        self,
        token_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        full_audio_codes,
        sections: list[PreparedSection],
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Insert each section's audio code ids between its [SOA]/[EOA] pair.

        The templated sequence contains one empty [SOA][EOA] placeholder per
        assistant turn. This copies the text tokens through unchanged and,
        for the k-th placeholder, writes section k's frame codes (offset by
        ``audio_prefix_length`` into tokenizer id space) between the markers.
        Assumes the number and order of [SOA]/[EOA] pairs matches ``sections``.
        """
        if not sections:
            return token_ids, attention_mask

        bos_positions = self._positions(token_ids, self.BOS_AUDIO)
        eos_positions = self._positions(token_ids, self.EOS_AUDIO)
        audio_code_tensor = torch.as_tensor(full_audio_codes, dtype=torch.long)

        # Pre-size the output buffers: original length + total audio frames.
        extra_audio_tokens = sum(
            int(section.end_frame) - int(section.start_frame)
            for section in sections
        )
        final_len = token_ids.numel() + extra_audio_tokens

        expanded_token_ids = torch.empty(final_len, dtype=torch.long)
        expanded_attention_mask = torch.empty(final_len, dtype=torch.long)
        read_pos = 0
        write_pos = 0

        for bos_pos, eos_pos, section in zip(
            bos_positions.tolist(),
            eos_positions.tolist(),
            sections,
        ):
            start_idx = int(section.start_frame)
            end_idx = int(section.end_frame)
            audio_len = max(0, end_idx - start_idx)

            # Copy everything up to and including this [SOA] marker.
            prefix_len = bos_pos + 1 - read_pos
            next_write = write_pos + prefix_len
            expanded_token_ids[write_pos:next_write] = token_ids[read_pos : bos_pos + 1]
            expanded_attention_mask[write_pos:next_write] = attention_mask[
                read_pos : bos_pos + 1
            ]
            write_pos = next_write

            # Write this section's audio codes, shifted into tokenizer ids.
            if audio_len > 0:
                next_write = write_pos + audio_len
                expanded_token_ids[write_pos:next_write] = audio_code_tensor[
                    start_idx:end_idx
                ]
                expanded_token_ids[write_pos:next_write].add_(self.audio_prefix_length)
                expanded_attention_mask[write_pos:next_write] = 1
                write_pos = next_write

            # Copy the matching [EOA] marker, then resume after it.
            expanded_token_ids[write_pos] = token_ids[eos_pos]
            expanded_attention_mask[write_pos] = attention_mask[eos_pos]
            write_pos += 1
            read_pos = eos_pos + 1

        # Copy any trailing template tokens after the last [EOA].
        tail_len = token_ids.numel() - read_pos
        if tail_len > 0:
            expanded_token_ids[write_pos : write_pos + tail_len] = token_ids[read_pos:]
            expanded_attention_mask[write_pos : write_pos + tail_len] = attention_mask[
                read_pos:
            ]

        return expanded_token_ids, expanded_attention_mask

    def build_labels(self, token_ids: torch.Tensor) -> torch.Tensor:
        """Mask everything except audio code ids and [EOA] with -100.

        NOTE(review): relies on audio code tokens occupying a contiguous id
        range starting at ``audio_prefix_length`` and ending just before
        ``MASK_AUDIO`` — established by add_audio_special_tokens; confirm if
        that helper changes.
        """
        audio_codebook_mask = (token_ids >= self.audio_prefix_length) & (
            token_ids < self.MASK_AUDIO
        )
        eos_mask = token_ids == self.EOS_AUDIO
        label_mask = audio_codebook_mask | eos_mask

        labels = token_ids.clone()
        labels[~label_mask] = -100
        return labels

    def render_audio_token_string(
        self,
        audio_token_ids: list[int],
        include_bos: bool = True,
        include_eos: bool = True,
    ) -> str:
        """Render audio token ids as token text, optionally [SOA]/[EOA]-wrapped."""
        parts: list[str] = []
        if include_bos:
            parts.append(SOA_TOKEN)
        parts.extend(self.tokenizer.convert_ids_to_tokens(audio_token_ids))
        if include_eos:
            parts.append(EOA_TOKEN)
        return "".join(parts)
|
|
|
|
def create_plain_qwen3_model(
    model_path: str,
    model_dtype: torch.dtype,
    attn_implementation: str,
    target_vocab_size: int,
) -> Qwen3ForCausalLM:
    """Load a Qwen3 causal LM and resize its embeddings to the target vocab.

    ``ignore_mismatched_sizes=True`` allows loading a checkpoint whose
    embedding matrix differs from the base config; the explicit
    ``resize_token_embeddings`` call then enforces ``target_vocab_size``.
    """
    print(f"Loading Qwen3 model from: {model_path}")
    config = AutoConfig.from_pretrained(
        model_path,
        local_files_only=True,
    )
    model = Qwen3ForCausalLM.from_pretrained(
        model_path,
        config=config,
        torch_dtype=model_dtype,
        attn_implementation=attn_implementation,
        ignore_mismatched_sizes=True,
        local_files_only=True,
    )
    model.resize_token_embeddings(target_vocab_size)

    # Report parameter counts for quick sanity checking in logs.
    sizes = [(p.numel(), p.requires_grad) for p in model.parameters()]
    total_params = sum(count for count, _ in sizes)
    trainable_params = sum(count for count, requires_grad in sizes if requires_grad)

    print(f"Total parameters: {total_params:,}")
    print(f"Trainable parameters: {trainable_params:,}")
    return model
|
|
|
|
def load_plain_qwen3_for_inference(
    model_path: str,
    device: torch.device,
    dtype: torch.dtype,
    attn_implementation: str,
    target_vocab_size: int,
) -> Qwen3ForCausalLM:
    """Build the model via create_plain_qwen3_model, move it to *device*,
    and switch it to eval mode."""
    model = create_plain_qwen3_model(
        model_path=model_path,
        model_dtype=dtype,
        attn_implementation=attn_implementation,
        target_vocab_size=target_vocab_size,
    )
    return model.to(device=device).eval()
|
|
|
|
def create_plain_ar_dataset(
    dataset_path: str,
    split: str,
    tokenizer_path: str,
    num_audio_token: int,
) -> PlainARMusicDataset:
    """Load an on-disk HF dataset and wrap the requested split.

    A bare ``Dataset`` (no split structure) is exposed under *split* so the
    wrapper can index it exactly like a real ``DatasetDict``.
    """
    hf_ds = datasets.load_from_disk(dataset_path)
    container = hf_ds if isinstance(hf_ds, datasets.DatasetDict) else {split: hf_ds}
    return PlainARMusicDataset(
        datasets_obj=container,
        split=split,
        tokenizer_path=tokenizer_path,
        num_audio_token=num_audio_token,
        use_fast=True,
    )
|
|
|
|
def run_train(args: argparse.Namespace) -> None:
    """Train the plain autoregressive Qwen3 model with the HF Trainer.

    Loads the train split and tokenizer, builds the model (optionally from a
    resume checkpoint), runs ``Trainer.train``, and saves the final model
    plus tokenizer under ``<output_dir>/final``.
    """
    model_dtype = get_model_dtype(args.model_dtype)
    # When resuming, config/model come from the checkpoint dir, not --model_path.
    model_source = resolve_model_source(
        model_path=args.model_path,
        resume_from_checkpoint=args.resume_from_checkpoint,
    )
    num_audio_token = resolve_num_audio_token(model_source, args.num_audio_token)
    print(f"Using num_audio_token={num_audio_token}")

    train_dataset = create_plain_ar_dataset(
        dataset_path=args.dataset_path,
        split="train",
        tokenizer_path=args.tokenizer_path,
        num_audio_token=num_audio_token,
    )
    print(f"Dataset size: {len(train_dataset)}")

    # The tokenizer was extended with audio tokens, so the embedding table is
    # resized to the dataset's post-extension vocabulary size.
    model = create_plain_qwen3_model(
        model_path=model_source,
        model_dtype=model_dtype,
        attn_implementation=args.attn_implementation,
        target_vocab_size=train_dataset.tokenizer_vocab_size,
    )

    training_args = TrainingArguments(
        output_dir=args.output_dir,
        per_device_train_batch_size=args.per_device_train_batch_size,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        learning_rate=args.learning_rate,
        weight_decay=args.weight_decay,
        num_train_epochs=args.num_train_epochs,
        warmup_steps=args.warmup_steps,
        max_grad_norm=args.max_grad_norm,
        logging_steps=args.logging_steps,
        save_strategy="epoch",
        dataloader_num_workers=args.dataloader_num_workers,
        # Mixed-precision flags mirror the requested model dtype.
        bf16=(args.model_dtype == "bfloat16"),
        fp16=(args.model_dtype == "float16"),
        gradient_checkpointing=args.gradient_checkpointing,
        gradient_checkpointing_kwargs={"use_reentrant": False},
        deepspeed=args.deepspeed,
        # Keep custom columns; the collator consumes the dataset rows directly.
        remove_unused_columns=False,
        dataloader_drop_last=True,
        report_to=args.report_to,
        logging_dir=None,
        run_name=args.wandb_run_name,
    )

    # W&B picks up the project name from the environment.
    if args.wandb_project and "wandb" in args.report_to:
        os.environ["WANDB_PROJECT"] = args.wandb_project

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        data_collator=PlainARDataCollator(
            pad_token_id=train_dataset.pad_token_id,
        ),
    )

    if args.resume_from_checkpoint:
        print(f"Resuming training from checkpoint: {args.resume_from_checkpoint}")
    else:
        print("Starting training from current model initialization.")

    trainer.train(resume_from_checkpoint=args.resume_from_checkpoint)
    final_dir = os.path.join(args.output_dir, "final")
    trainer.save_model(final_dir)
    train_dataset.tokenizer.save_pretrained(final_dir)
    print(f"Training complete. Final model saved to: {final_dir}")
|
|
|
|
def sanitize_generated_section(
    generated_ids: list[int],
    eos_audio_id: int,
    audio_start: int,
    audio_end: int,
) -> list[int]:
    """Validate one section's generated ids and strip the trailing [EOA].

    Raises:
        RuntimeError: if generation produced no tokens, did not end with
            ``eos_audio_id``, or emitted any id outside ``[audio_start,
            audio_end)`` before the terminator.
    """
    if not generated_ids:
        raise RuntimeError("Generation returned no new tokens for the current section.")
    if generated_ids[-1] != eos_audio_id:
        raise RuntimeError(
            "Section generation did not terminate with [EOA]. "
            "Increase --max_new_tokens_per_section or inspect the checkpoint."
        )
    audio_ids = generated_ids[:-1]
    invalid_ids = [tid for tid in audio_ids if tid < audio_start or tid >= audio_end]
    if invalid_ids:
        preview = invalid_ids[:8]
        raise RuntimeError(
            "Section generation produced non-audio tokens before [EOA]: "
            f"{preview}"
        )
    return audio_ids
|
|
|
|
@torch.inference_mode()
def generate_sections_autoregressively(
    model: Qwen3ForCausalLM,
    music_ds: PlainARMusicDataset,
    sample: dict[str, Any],
    device: torch.device,
    args: argparse.Namespace,
) -> tuple[list[list[int]], list[dict[str, Any]]]:
    """Decode the song section by section, feeding prior audio back in.

    For each section, the running chat transcript gains a user prompt and an
    assistant turn seeded with [SOA]; the model generates until it emits the
    audio EOS. The validated ids are spliced back into the transcript so
    later sections condition on all previously generated audio.

    Returns:
        A pair of (per-section audio token ids, per-section metadata dicts).
    """
    messages: list[dict[str, str]] = []
    sections = music_ds.prepare_sections(sample)
    section_records: list[dict[str, Any]] = []
    # --no_cache overrides --use_cache (store_true with default=True).
    use_cache = args.use_cache and not args.no_cache

    eos_token_id = music_ds.EOS_AUDIO
    # generate() needs a pad id for finished sequences; prefer the text EOS.
    pad_token_id = (
        int(music_ds.tokenizer.eos_token_id)
        if music_ds.tokenizer.eos_token_id is not None
        else music_ds.pad_token_id
    )

    all_section_audio_ids: list[list[int]] = []
    for section_idx, section in enumerate(sections):
        user_content = music_ds.build_section_user_content(
            sample=sample,
            section=section,
            is_first_turn=(section_idx == 0),
        )
        messages.append({"role": "user", "content": user_content})
        # Seed the assistant turn with [SOA] so generation starts in audio mode.
        # NOTE(review): apply_chat_template(add_generation_prompt=False) may
        # append turn-closing tokens after [SOA]; confirm this matches the
        # training layout where audio codes directly follow [SOA] (a
        # continue_final_message-style template call may be intended here).
        messages.append({"role": "assistant", "content": SOA_TOKEN})

        chat_inputs = music_ds.tokenizer.apply_chat_template(
            messages,
            tokenize=True,
            add_generation_prompt=False,
            return_tensors="pt",
            return_dict=True,
            **music_ds._chat_template_kwargs,
        )
        input_ids = chat_inputs["input_ids"].to(device)
        attention_mask = chat_inputs["attention_mask"].to(device)

        generated = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            do_sample=not bool(args.greedy),
            temperature=float(args.temperature),
            top_k=int(args.top_k),
            top_p=float(args.top_p),
            max_new_tokens=int(args.max_new_tokens_per_section),
            eos_token_id=eos_token_id,
            pad_token_id=pad_token_id,
            use_cache=use_cache,
        )
        # Keep only the tokens generated beyond the prompt length.
        new_ids = generated[0, input_ids.shape[1] :].tolist()
        audio_ids = sanitize_generated_section(
            generated_ids=new_ids,
            eos_audio_id=music_ds.EOS_AUDIO,
            audio_start=music_ds.audio_prefix_length,
            audio_end=music_ds.MASK_AUDIO,
        )

        all_section_audio_ids.append(audio_ids)
        # Replace the [SOA] seed with the full [SOA]...[EOA] audio string so
        # subsequent sections see everything generated so far.
        messages[-1]["content"] = music_ds.render_audio_token_string(
            audio_token_ids=audio_ids,
            include_bos=True,
            include_eos=True,
        )

        section_records.append(
            {
                "section_index": section_idx,
                "section_label": music_ds.format_section_label(section),
                "desc": section.desc,
                "lyrics": section.text,
                "generated_audio_tokens": len(audio_ids),
            }
        )
        print(
            f"[INFO] section={section_idx} "
            f"label={music_ds.format_section_label(section)!r} "
            f"generated_audio_tokens={len(audio_ids)}"
        )

    return all_section_audio_ids, section_records
|
|
|
|
def save_inference_outputs(
    output_dir: str,
    output_prefix: str,
    sample_idx: int,
    sample: dict[str, Any],
    section_audio_ids: list[list[int]],
    section_records: list[dict[str, Any]],
    music_ds: PlainARMusicDataset,
    args: argparse.Namespace,
) -> None:
    """Write the generation summary JSON and, unless skipped, the decoded wav.

    Args:
        output_dir: Directory for the output files (created if missing).
        output_prefix: File stem; when empty, one is derived from the song id,
            sample index, and a timestamp.
        sample_idx: Dataset index of the generated sample.
        sample: Raw dataset record; ``song_id`` is optional.
        section_audio_ids: Per-section generated audio token ids.
        section_records: Per-section metadata dicts for the JSON payload.
        music_ds: Dataset wrapper providing the audio id offset for decoding.
        args: Inference CLI namespace (skip_decode and MuCodec options).
    """
    import json

    Path(output_dir).mkdir(parents=True, exist_ok=True)
    stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    # Fix: use the same guarded lookup as the payload; previously the prefix
    # used sample["song_id"] directly and raised KeyError when it was absent.
    song_id = str(sample.get("song_id", f"sample_{sample_idx}"))
    prefix = output_prefix or f"{song_id}_{sample_idx}_{stamp}"

    json_path = Path(output_dir) / f"{prefix}.json"
    wav_path = Path(output_dir) / f"{prefix}.wav"

    flat_token_ids: list[int] = []
    for section_ids in section_audio_ids:
        flat_token_ids.extend(section_ids)

    payload = {
        "song_id": song_id,
        "sample_idx": int(sample_idx),
        "num_sections": len(section_audio_ids),
        "generated_audio_tokens": len(flat_token_ids),
        "sections": section_records,
    }
    with open(json_path, "w", encoding="utf-8") as f:
        json.dump(payload, f, ensure_ascii=False, indent=2)
    print(f"[OK] {json_path}")

    if args.skip_decode:
        return

    # Undo the vocabulary offset applied when audio codes were embedded into
    # the text token space, recovering raw MuCodec code indices.
    shifted_codes = np.asarray(flat_token_ids, dtype=np.int64) - music_ds.audio_prefix_length
    if shifted_codes.size == 0:
        print("[WARN] No generated MuCodec tokens; skipping wav decode.")
        return

    # Lazy import: torchaudio is only required when actually writing audio.
    import torchaudio

    mucodec_decoder = build_mucodec_decoder(args)
    wave = decode_mucodec_codes(
        mucodec_decoder=mucodec_decoder,
        shifted_codes=shifted_codes,
        args=args,
    )
    torchaudio.save(str(wav_path), wave, int(args.mucodec_sample_rate))
    print(f"[OK] {wav_path}")
|
|
|
|
def run_infer(args: argparse.Namespace) -> None:
    """Generate audio tokens for one dataset sample and save JSON (and wav).

    Seeds all RNGs, loads the dataset/tokenizer and checkpoint, runs the
    section-wise autoregressive decode, and writes the outputs.
    """
    seed_everything(args.seed)

    # The tokenizer defaults to the model checkpoint when not given explicitly.
    tokenizer_path = args.tokenizer_path or args.model_path
    num_audio_token = resolve_num_audio_token(args.model_path, args.num_audio_token)
    print(f"Using num_audio_token={num_audio_token}")

    device = resolve_device(args.device)
    dtype = get_model_dtype(args.dtype)
    # Half precision is often unsupported on CPU; fall back to float32.
    if device.type == "cpu" and dtype != torch.float32:
        print(f"[WARN] dtype {dtype} on CPU may be unsupported; fallback to float32.")
        dtype = torch.float32

    music_ds = create_plain_ar_dataset(
        dataset_path=args.dataset_path,
        split=args.split,
        tokenizer_path=tokenizer_path,
        num_audio_token=num_audio_token,
    )
    sample = music_ds.raw_sample(args.sample_idx)
    model = load_plain_qwen3_for_inference(
        model_path=args.model_path,
        device=device,
        dtype=dtype,
        attn_implementation=args.attn_implementation,
        target_vocab_size=music_ds.tokenizer_vocab_size,
    )

    section_audio_ids, section_records = generate_sections_autoregressively(
        model=model,
        music_ds=music_ds,
        sample=sample,
        device=device,
        args=args,
    )
    save_inference_outputs(
        output_dir=args.output_dir,
        output_prefix=args.output_prefix,
        sample_idx=args.sample_idx,
        sample=sample,
        section_audio_ids=section_audio_ids,
        section_records=section_records,
        music_ds=music_ds,
        args=args,
    )
|
|
|
|
def main() -> None:
    """Entry point: parse the CLI and dispatch to the chosen subcommand."""
    args = parse_args()
    handlers = {"train": run_train, "infer": run_infer}
    handler = handlers.get(args.command)
    if handler is None:
        raise ValueError(f"Unknown command: {args.command}")
    handler(args)
|
|
|
|
| if __name__ == "__main__": |
| main() |
|
|