#!/usr/bin/env python3 # -*- coding: utf-8 -*- """ Pure Qwen3 autoregressive training/inference for music token generation. This variant keeps: - style - section structure in the prompt - section lyrics/description This variant removes: - frame-level chord conditioning - frame-level structure conditioning - condition encoder / AdaLN injection """ from __future__ import annotations import argparse import os from dataclasses import dataclass from datetime import datetime from pathlib import Path from typing import Any import datasets import numpy as np import torch from torch.nn.utils.rnn import pad_sequence from transformers import AutoConfig, AutoTokenizer, Trainer, TrainingArguments from transformers.models.qwen3.modeling_qwen3 import Qwen3ForCausalLM from audio_tokens import ( EOA_TOKEN, MASK_AUDIO_TOKEN, SOA_TOKEN, add_audio_special_tokens, audio_id_to_token, ) from dataset import normalize_section_text, SECTION_NAME_MAP, SINGLETON_SECTION_NAMES from inference_full import build_mucodec_decoder, decode_mucodec_codes from runtime_utils import resolve_device, seed_everything from vocab import normalize_structure_label def parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser( description="Pure Qwen3 autoregressive training/inference without frame-level conditioning." ) subparsers = parser.add_subparsers(dest="command", required=True) train_parser = subparsers.add_parser( "train", help="Train a plain Qwen3 autoregressive model on section prompts and audio tokens.", ) add_train_args(train_parser) infer_parser = subparsers.add_parser( "infer", help="Run section-wise autoregressive inference with a plain Qwen3 checkpoint.", ) add_infer_args(infer_parser) return parser.parse_args() def add_train_args(parser: argparse.ArgumentParser) -> None: parser.add_argument("--dataset_path", type=str, default="muse_mucodec_chord.ds") parser.add_argument( "--model_path", type=str, default="checkpoints/Qwen3-0.6B", help="Local Qwen3 base checkpoint path.", ) parser.add_argument( "--tokenizer_path", type=str, default="checkpoints/Qwen3-0.6B", help="Local tokenizer checkpoint path.", ) parser.add_argument( "--num_audio_token", type=int, default=None, help="Audio codebook size. Defaults to checkpoint metadata when available, else 16384.", ) parser.add_argument( "--model_dtype", type=str, default="bfloat16", choices=["float32", "float16", "bfloat16"], ) parser.add_argument( "--attn_implementation", type=str, default="sdpa", choices=["eager", "sdpa", "flash_attention_2"], ) parser.add_argument("--output_dir", type=str, default="./output_qwen3_plain_ar") parser.add_argument("--per_device_train_batch_size", type=int, default=1) parser.add_argument("--gradient_accumulation_steps", type=int, default=4) parser.add_argument("--learning_rate", type=float, default=1e-4) parser.add_argument("--weight_decay", type=float, default=0.01) parser.add_argument("--num_train_epochs", type=float, default=20) parser.add_argument("--warmup_steps", type=int, default=1000) parser.add_argument("--max_grad_norm", type=float, default=5.0) parser.add_argument("--logging_steps", type=int, default=10) parser.add_argument( "--resume_from_checkpoint", type=str, default=None, help="Resume training from a Trainer checkpoint directory such as output_dir/checkpoint-500.", ) parser.add_argument("--dataloader_num_workers", type=int, default=12) parser.add_argument( "--gradient_checkpointing", dest="gradient_checkpointing", action="store_true", ) parser.add_argument( "--deepspeed", type=str, default=None, help="Path to DeepSpeed config. Leave unset to disable DeepSpeed.", ) parser.add_argument("--report_to", type=str, default="wandb") parser.add_argument("--wandb_project", type=str, default="vaultum-qwen3-0p6b") parser.add_argument("--wandb_run_name", type=str, default=None) def add_infer_args(parser: argparse.ArgumentParser) -> None: parser.add_argument("--model_path", type=str, required=True) parser.add_argument( "--tokenizer_path", type=str, default=None, help="Tokenizer path. Defaults to --model_path.", ) parser.add_argument("--dataset_path", type=str, default="muse_mucodec_chord.ds") parser.add_argument("--split", type=str, default="validation") parser.add_argument("--sample_idx", type=int, default=0) parser.add_argument( "--num_audio_token", type=int, default=None, help="Audio codebook size. Defaults to checkpoint metadata when available, else 16384.", ) parser.add_argument("--seed", type=int, default=1234) parser.add_argument("--device", type=str, default="auto") parser.add_argument( "--dtype", type=str, default="bfloat16", choices=["float32", "float16", "bfloat16"], ) parser.add_argument( "--attn_implementation", type=str, default="sdpa", choices=["eager", "sdpa", "flash_attention_2"], ) parser.add_argument("--temperature", type=float, default=1.0) parser.add_argument("--top_k", type=int, default=50) parser.add_argument("--top_p", type=float, default=0.90) parser.add_argument("--greedy", action="store_true", default=False) parser.add_argument("--use_cache", action="store_true", default=True) parser.add_argument("--no_cache", action="store_true", default=False) parser.add_argument( "--max_new_tokens_per_section", type=int, default=2048, help="Upper bound for each section decode before forcing a failure.", ) parser.add_argument("--output_dir", type=str, default="plain_ar_predictions") parser.add_argument("--output_prefix", type=str, default="") parser.add_argument("--skip_decode", action="store_true", default=False) parser.add_argument("--mucodec_device", type=str, default="auto") parser.add_argument("--mucodec_layer_num", type=int, default=7) parser.add_argument("--mucodec_duration", type=float, default=40.96) parser.add_argument("--mucodec_guidance_scale", type=float, default=1.5) parser.add_argument("--mucodec_num_steps", type=int, default=20) parser.add_argument("--mucodec_sample_rate", type=int, default=48000) def resolve_model_source(model_path: str, resume_from_checkpoint: str | None) -> str: if not resume_from_checkpoint: return model_path if os.path.abspath(model_path) != os.path.abspath(resume_from_checkpoint): print( "Ignoring --model_path during resume and loading config/model from: " f"{resume_from_checkpoint}" ) return resume_from_checkpoint def get_model_dtype(name: str) -> torch.dtype: return { "float32": torch.float32, "float16": torch.float16, "bfloat16": torch.bfloat16, }[name] def resolve_num_audio_token(checkpoint_path: str, explicit_value: int | None) -> int: if explicit_value is not None: return int(explicit_value) config = AutoConfig.from_pretrained( checkpoint_path, local_files_only=True, ) return int(getattr(config, "magel_num_audio_token", 16384)) @dataclass class PreparedSection: text: str desc: str start_frame: int end_frame: int structure: str index: int class PlainARDataCollator: def __init__(self, pad_token_id: int = 0): self.pad_token_id = int(pad_token_id) def __call__(self, batch: list[dict[str, torch.Tensor]]) -> dict[str, torch.Tensor]: return { "input_ids": pad_sequence( [row["input_ids"] for row in batch], batch_first=True, padding_value=self.pad_token_id, ), "attention_mask": pad_sequence( [row["attention_mask"] for row in batch], batch_first=True, padding_value=0, ), "labels": pad_sequence( [row["labels"] for row in batch], batch_first=True, padding_value=-100, ), } class PlainARMusicDataset(torch.utils.data.Dataset): def __init__( self, datasets_obj, split: str, tokenizer_path: str, num_audio_token: int = 16384, fps: int = 25, use_fast: bool = True, ): self._data = datasets_obj[split] self.tokenizer_path = tokenizer_path self.num_audio_token = int(num_audio_token) self.fps = int(fps) self.use_fast = bool(use_fast) self.tokenizer = AutoTokenizer.from_pretrained( tokenizer_path, local_files_only=True, use_fast=use_fast, ) add_audio_special_tokens(self.tokenizer, self.num_audio_token) self.tokenizer_vocab_size = len(self.tokenizer) self.audio_prefix_length = int( self.tokenizer.convert_tokens_to_ids(audio_id_to_token(0)) ) self.MASK_AUDIO = int(self.tokenizer.convert_tokens_to_ids(MASK_AUDIO_TOKEN)) self.BOS_AUDIO = int(self.tokenizer.convert_tokens_to_ids(SOA_TOKEN)) self.EOS_AUDIO = int(self.tokenizer.convert_tokens_to_ids(EOA_TOKEN)) self.pad_token_id = ( int(self.tokenizer.pad_token_id) if self.tokenizer.pad_token_id is not None else 0 ) self._assistant_audio_placeholder = f"{SOA_TOKEN}{EOA_TOKEN}" self._chat_template_kwargs = {"enable_thinking": False} def __len__(self) -> int: return len(self._data) def __getitem__(self, idx: int) -> dict[str, torch.Tensor]: sample = self.raw_sample(idx) sections = self.prepare_sections(sample) token_ids, attention_mask = self.tokenize_messages( self.build_messages(sample, sections), sample["mucodec_codes"], sections, ) labels = self.build_labels(token_ids) return { "input_ids": token_ids, "attention_mask": attention_mask, "labels": labels, } def raw_sample(self, idx: int) -> dict[str, Any]: return self._data[idx] @staticmethod def _positions(token_ids: torch.Tensor, target_id: int) -> torch.Tensor: return torch.nonzero(token_ids == target_id, as_tuple=False).squeeze(-1) @staticmethod def _sorted_sections(sample: dict[str, Any]) -> list[dict[str, Any]]: return sorted( ( { "raw_index": raw_index, "text": str(seg["text"]), "desc": str(seg["desc"]).strip(), "start": float(seg["start"]), "end": float(seg["end"]), "structure": normalize_structure_label(seg["section"]), } for raw_index, seg in enumerate(sample.get("sections", [])) ), key=lambda seg: (seg["start"], seg["end"], seg["raw_index"]), ) def prepare_sections(self, sample: dict[str, Any]) -> list[PreparedSection]: sections: list[PreparedSection] = [] section_counts: dict[str, int] = {} total_frames = len(sample["mucodec_codes"]) prev_end_idx = 0 sample_language = sample.get("language") for seg in self._sorted_sections(sample): structure = seg["structure"] section_counts[structure] = section_counts.get(structure, 0) + 1 raw_end_idx = max( prev_end_idx, min(total_frames, int(np.ceil(seg["end"] * self.fps))), ) sections.append( PreparedSection( text=normalize_section_text( seg["text"], structure, language=sample_language, ), desc=seg["desc"], start_frame=prev_end_idx, end_frame=raw_end_idx, structure=structure, index=section_counts[structure], ) ) prev_end_idx = raw_end_idx if sections: last = sections[-1] sections[-1] = PreparedSection( text=last.text, desc=last.desc, start_frame=last.start_frame, end_frame=total_frames, structure=last.structure, index=last.index, ) return sections def format_section_label(self, section: PreparedSection) -> str: label = SECTION_NAME_MAP[section.structure] if section.structure in SINGLETON_SECTION_NAMES and section.index == 1: return label return f"{label} {section.index}" def build_section_user_content( self, sample: dict[str, Any], section: PreparedSection, is_first_turn: bool, ) -> str: parts: list[str] = [] if is_first_turn: style = str(sample.get("style", "")).strip() if style: parts.append( f"Please generate a song in the following style:{style}\n" "Next, I will tell you the requirements and lyrics for the song " "fragment to be generated, section by section." ) else: parts.append( "Please generate the song section by section. " "Next, I will tell you the requirements and lyrics for each fragment." ) section_parts = [f"[{self.format_section_label(section)}]"] if section.desc: section_parts.append(f"[desc:{section.desc}]") if section.text: section_parts.append(f"[lyrics:{section.text}]") parts.append("".join(section_parts)) return "\n".join(part for part in parts if part) def build_messages( self, sample: dict[str, Any], sections: list[PreparedSection], ) -> list[dict[str, str]]: messages: list[dict[str, str]] = [] for idx, section in enumerate(sections): messages.append( { "role": "user", "content": self.build_section_user_content( sample=sample, section=section, is_first_turn=(idx == 0), ), } ) messages.append( { "role": "assistant", "content": self._assistant_audio_placeholder, } ) return messages def tokenize_messages( self, messages: list[dict[str, str]], full_audio_codes, sections: list[PreparedSection], ) -> tuple[torch.Tensor, torch.Tensor]: chat_inputs = self.tokenizer.apply_chat_template( messages, tokenize=True, add_generation_prompt=False, return_tensors="pt", return_dict=True, **self._chat_template_kwargs, ) token_ids = chat_inputs["input_ids"].squeeze(0).to(torch.long) attention_mask = chat_inputs["attention_mask"].squeeze(0).to(torch.long) return self.expand_audio_tokens( token_ids=token_ids, attention_mask=attention_mask, full_audio_codes=full_audio_codes, sections=sections, ) def expand_audio_tokens( self, token_ids: torch.Tensor, attention_mask: torch.Tensor, full_audio_codes, sections: list[PreparedSection], ) -> tuple[torch.Tensor, torch.Tensor]: if not sections: return token_ids, attention_mask bos_positions = self._positions(token_ids, self.BOS_AUDIO) eos_positions = self._positions(token_ids, self.EOS_AUDIO) audio_code_tensor = torch.as_tensor(full_audio_codes, dtype=torch.long) extra_audio_tokens = sum( int(section.end_frame) - int(section.start_frame) for section in sections ) final_len = token_ids.numel() + extra_audio_tokens expanded_token_ids = torch.empty(final_len, dtype=torch.long) expanded_attention_mask = torch.empty(final_len, dtype=torch.long) read_pos = 0 write_pos = 0 for bos_pos, eos_pos, section in zip( bos_positions.tolist(), eos_positions.tolist(), sections, ): start_idx = int(section.start_frame) end_idx = int(section.end_frame) audio_len = max(0, end_idx - start_idx) prefix_len = bos_pos + 1 - read_pos next_write = write_pos + prefix_len expanded_token_ids[write_pos:next_write] = token_ids[read_pos : bos_pos + 1] expanded_attention_mask[write_pos:next_write] = attention_mask[ read_pos : bos_pos + 1 ] write_pos = next_write if audio_len > 0: next_write = write_pos + audio_len expanded_token_ids[write_pos:next_write] = audio_code_tensor[ start_idx:end_idx ] expanded_token_ids[write_pos:next_write].add_(self.audio_prefix_length) expanded_attention_mask[write_pos:next_write] = 1 write_pos = next_write expanded_token_ids[write_pos] = token_ids[eos_pos] expanded_attention_mask[write_pos] = attention_mask[eos_pos] write_pos += 1 read_pos = eos_pos + 1 tail_len = token_ids.numel() - read_pos if tail_len > 0: expanded_token_ids[write_pos : write_pos + tail_len] = token_ids[read_pos:] expanded_attention_mask[write_pos : write_pos + tail_len] = attention_mask[ read_pos: ] return expanded_token_ids, expanded_attention_mask def build_labels(self, token_ids: torch.Tensor) -> torch.Tensor: audio_codebook_mask = (token_ids >= self.audio_prefix_length) & ( token_ids < self.MASK_AUDIO ) eos_mask = token_ids == self.EOS_AUDIO label_mask = audio_codebook_mask | eos_mask labels = token_ids.clone() labels[~label_mask] = -100 return labels def render_audio_token_string( self, audio_token_ids: list[int], include_bos: bool = True, include_eos: bool = True, ) -> str: parts: list[str] = [] if include_bos: parts.append(SOA_TOKEN) parts.extend(self.tokenizer.convert_ids_to_tokens(audio_token_ids)) if include_eos: parts.append(EOA_TOKEN) return "".join(parts) def create_plain_qwen3_model( model_path: str, model_dtype: torch.dtype, attn_implementation: str, target_vocab_size: int, ) -> Qwen3ForCausalLM: print(f"Loading Qwen3 model from: {model_path}") config = AutoConfig.from_pretrained( model_path, local_files_only=True, ) model = Qwen3ForCausalLM.from_pretrained( model_path, config=config, torch_dtype=model_dtype, attn_implementation=attn_implementation, ignore_mismatched_sizes=True, local_files_only=True, ) model.resize_token_embeddings(target_vocab_size) total_params = sum(p.numel() for p in model.parameters()) trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) print(f"Total parameters: {total_params:,}") print(f"Trainable parameters: {trainable_params:,}") return model def load_plain_qwen3_for_inference( model_path: str, device: torch.device, dtype: torch.dtype, attn_implementation: str, target_vocab_size: int, ) -> Qwen3ForCausalLM: model = create_plain_qwen3_model( model_path=model_path, model_dtype=dtype, attn_implementation=attn_implementation, target_vocab_size=target_vocab_size, ) model.to(device=device) model.eval() return model def create_plain_ar_dataset( dataset_path: str, split: str, tokenizer_path: str, num_audio_token: int, ) -> PlainARMusicDataset: hf_ds = datasets.load_from_disk(dataset_path) if isinstance(hf_ds, datasets.DatasetDict): container = hf_ds else: container = {split: hf_ds} return PlainARMusicDataset( datasets_obj=container, split=split, tokenizer_path=tokenizer_path, num_audio_token=num_audio_token, use_fast=True, ) def run_train(args: argparse.Namespace) -> None: model_dtype = get_model_dtype(args.model_dtype) model_source = resolve_model_source( model_path=args.model_path, resume_from_checkpoint=args.resume_from_checkpoint, ) num_audio_token = resolve_num_audio_token(model_source, args.num_audio_token) print(f"Using num_audio_token={num_audio_token}") train_dataset = create_plain_ar_dataset( dataset_path=args.dataset_path, split="train", tokenizer_path=args.tokenizer_path, num_audio_token=num_audio_token, ) print(f"Dataset size: {len(train_dataset)}") model = create_plain_qwen3_model( model_path=model_source, model_dtype=model_dtype, attn_implementation=args.attn_implementation, target_vocab_size=train_dataset.tokenizer_vocab_size, ) training_args = TrainingArguments( output_dir=args.output_dir, per_device_train_batch_size=args.per_device_train_batch_size, gradient_accumulation_steps=args.gradient_accumulation_steps, learning_rate=args.learning_rate, weight_decay=args.weight_decay, num_train_epochs=args.num_train_epochs, warmup_steps=args.warmup_steps, max_grad_norm=args.max_grad_norm, logging_steps=args.logging_steps, save_strategy="epoch", dataloader_num_workers=args.dataloader_num_workers, bf16=(args.model_dtype == "bfloat16"), fp16=(args.model_dtype == "float16"), gradient_checkpointing=args.gradient_checkpointing, gradient_checkpointing_kwargs={"use_reentrant": False}, deepspeed=args.deepspeed, remove_unused_columns=False, dataloader_drop_last=True, report_to=args.report_to, logging_dir=None, run_name=args.wandb_run_name, ) if args.wandb_project and "wandb" in args.report_to: os.environ["WANDB_PROJECT"] = args.wandb_project trainer = Trainer( model=model, args=training_args, train_dataset=train_dataset, data_collator=PlainARDataCollator( pad_token_id=train_dataset.pad_token_id, ), ) if args.resume_from_checkpoint: print(f"Resuming training from checkpoint: {args.resume_from_checkpoint}") else: print("Starting training from current model initialization.") trainer.train(resume_from_checkpoint=args.resume_from_checkpoint) final_dir = os.path.join(args.output_dir, "final") trainer.save_model(final_dir) train_dataset.tokenizer.save_pretrained(final_dir) print(f"Training complete. Final model saved to: {final_dir}") def sanitize_generated_section( generated_ids: list[int], eos_audio_id: int, audio_start: int, audio_end: int, ) -> list[int]: if not generated_ids: raise RuntimeError("Generation returned no new tokens for the current section.") if generated_ids[-1] != eos_audio_id: raise RuntimeError( "Section generation did not terminate with [EOA]. " "Increase --max_new_tokens_per_section or inspect the checkpoint." ) invalid_ids = [ tid for tid in generated_ids[:-1] if not (audio_start <= tid < audio_end) ] if invalid_ids: preview = invalid_ids[:8] raise RuntimeError( "Section generation produced non-audio tokens before [EOA]: " f"{preview}" ) audio_ids = generated_ids[:-1] return audio_ids @torch.inference_mode() def generate_sections_autoregressively( model: Qwen3ForCausalLM, music_ds: PlainARMusicDataset, sample: dict[str, Any], device: torch.device, args: argparse.Namespace, ) -> tuple[list[list[int]], list[dict[str, Any]]]: messages: list[dict[str, str]] = [] sections = music_ds.prepare_sections(sample) section_records: list[dict[str, Any]] = [] use_cache = args.use_cache and not args.no_cache eos_token_id = music_ds.EOS_AUDIO pad_token_id = ( int(music_ds.tokenizer.eos_token_id) if music_ds.tokenizer.eos_token_id is not None else music_ds.pad_token_id ) all_section_audio_ids: list[list[int]] = [] for section_idx, section in enumerate(sections): user_content = music_ds.build_section_user_content( sample=sample, section=section, is_first_turn=(section_idx == 0), ) messages.append({"role": "user", "content": user_content}) messages.append({"role": "assistant", "content": SOA_TOKEN}) chat_inputs = music_ds.tokenizer.apply_chat_template( messages, tokenize=True, add_generation_prompt=False, return_tensors="pt", return_dict=True, **music_ds._chat_template_kwargs, ) input_ids = chat_inputs["input_ids"].to(device) attention_mask = chat_inputs["attention_mask"].to(device) generated = model.generate( input_ids=input_ids, attention_mask=attention_mask, do_sample=not bool(args.greedy), temperature=float(args.temperature), top_k=int(args.top_k), top_p=float(args.top_p), max_new_tokens=int(args.max_new_tokens_per_section), eos_token_id=eos_token_id, pad_token_id=pad_token_id, use_cache=use_cache, ) new_ids = generated[0, input_ids.shape[1] :].tolist() audio_ids = sanitize_generated_section( generated_ids=new_ids, eos_audio_id=music_ds.EOS_AUDIO, audio_start=music_ds.audio_prefix_length, audio_end=music_ds.MASK_AUDIO, ) all_section_audio_ids.append(audio_ids) messages[-1]["content"] = music_ds.render_audio_token_string( audio_token_ids=audio_ids, include_bos=True, include_eos=True, ) section_records.append( { "section_index": section_idx, "section_label": music_ds.format_section_label(section), "desc": section.desc, "lyrics": section.text, "generated_audio_tokens": len(audio_ids), } ) print( f"[INFO] section={section_idx} " f"label={music_ds.format_section_label(section)!r} " f"generated_audio_tokens={len(audio_ids)}" ) return all_section_audio_ids, section_records def save_inference_outputs( output_dir: str, output_prefix: str, sample_idx: int, sample: dict[str, Any], section_audio_ids: list[list[int]], section_records: list[dict[str, Any]], music_ds: PlainARMusicDataset, args: argparse.Namespace, ) -> None: Path(output_dir).mkdir(parents=True, exist_ok=True) stamp = datetime.now().strftime("%Y%m%d_%H%M%S") prefix = output_prefix or f"{sample['song_id']}_{sample_idx}_{stamp}" json_path = Path(output_dir) / f"{prefix}.json" wav_path = Path(output_dir) / f"{prefix}.wav" flat_token_ids: list[int] = [] for section_ids in section_audio_ids: flat_token_ids.extend(section_ids) payload = { "song_id": str(sample.get("song_id", f"sample_{sample_idx}")), "sample_idx": int(sample_idx), "num_sections": len(section_audio_ids), "generated_audio_tokens": len(flat_token_ids), "sections": section_records, } with open(json_path, "w", encoding="utf-8") as f: import json json.dump(payload, f, ensure_ascii=False, indent=2) print(f"[OK] {json_path}") if args.skip_decode: return shifted_codes = np.asarray(flat_token_ids, dtype=np.int64) - music_ds.audio_prefix_length if shifted_codes.size == 0: print("[WARN] No generated MuCodec tokens; skipping wav decode.") return import torchaudio mucodec_decoder = build_mucodec_decoder(args) wave = decode_mucodec_codes( mucodec_decoder=mucodec_decoder, shifted_codes=shifted_codes, args=args, ) torchaudio.save(str(wav_path), wave, int(args.mucodec_sample_rate)) print(f"[OK] {wav_path}") def run_infer(args: argparse.Namespace) -> None: seed_everything(args.seed) tokenizer_path = args.tokenizer_path or args.model_path num_audio_token = resolve_num_audio_token(args.model_path, args.num_audio_token) print(f"Using num_audio_token={num_audio_token}") device = resolve_device(args.device) dtype = get_model_dtype(args.dtype) if device.type == "cpu" and dtype != torch.float32: print(f"[WARN] dtype {dtype} on CPU may be unsupported; fallback to float32.") dtype = torch.float32 music_ds = create_plain_ar_dataset( dataset_path=args.dataset_path, split=args.split, tokenizer_path=tokenizer_path, num_audio_token=num_audio_token, ) sample = music_ds.raw_sample(args.sample_idx) model = load_plain_qwen3_for_inference( model_path=args.model_path, device=device, dtype=dtype, attn_implementation=args.attn_implementation, target_vocab_size=music_ds.tokenizer_vocab_size, ) section_audio_ids, section_records = generate_sections_autoregressively( model=model, music_ds=music_ds, sample=sample, device=device, args=args, ) save_inference_outputs( output_dir=args.output_dir, output_prefix=args.output_prefix, sample_idx=args.sample_idx, sample=sample, section_audio_ids=section_audio_ids, section_records=section_records, music_ds=music_ds, args=args, ) def main() -> None: args = parse_args() if args.command == "train": run_train(args) return if args.command == "infer": run_infer(args) return raise ValueError(f"Unknown command: {args.command}") if __name__ == "__main__": main()