#!/usr/bin/env python3 # -*- coding: utf-8 -*- import argparse import json import traceback from pathlib import Path import datasets import torch from inference_full import ( TokenLayout, batch_generate_segmentwise, build_mucodec_decoder, generate_segmentwise, load_hf_template_sample_from_music_dataset, save_outputs, ) from runtime_utils import ( load_magel_checkpoint, load_music_dataset, maybe_compile_model, resolve_device, seed_everything, ) def parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser( description="Run audio inference on validation samples for multiple checkpoints." ) parser.add_argument( "--checkpoint_list", type=str, default=None, help="Text file with one checkpoint path per line.", ) parser.add_argument( "--checkpoint_dir", type=str, default=None, help="Directory to scan for checkpoint-* subdirectories and optional final.", ) parser.add_argument( "--dataset_path", type=str, default="muse_mucodec_chord.ds", ) parser.add_argument( "--split", type=str, default="validation", ) parser.add_argument( "--tokenizer_path", type=str, default="checkpoints/Qwen3-0.6B", ) parser.add_argument( "--sample_indices", type=int, nargs="*", default=None, help="Specific sample indices to infer. Leave unset to run the full split.", ) parser.add_argument( "--max_samples", type=int, default=0, help="Run only the first N samples from the split. Ignored if --sample_indices is set.", ) parser.add_argument( "--infer_batch_size", type=int, default=1, help="Number of samples to decode together per step for the same checkpoint.", ) parser.add_argument("--temperature", type=float, default=1.0) parser.add_argument("--top_k", type=int, default=50) parser.add_argument("--top_p", type=float, default=0.90) parser.add_argument("--greedy", action="store_true", default=False) parser.add_argument("--max_audio_tokens", type=int, default=0) parser.add_argument("--fps", type=int, default=25) parser.add_argument("--seed", type=int, default=1234) parser.add_argument("--device", type=str, default="auto") parser.add_argument( "--dtype", type=str, default="bfloat16", choices=["float32", "float16", "bfloat16"], ) parser.add_argument( "--attn_implementation", type=str, default="sdpa", choices=["eager", "sdpa", "flash_attention_2"], ) parser.add_argument("--use_cache", action="store_true", default=True) parser.add_argument("--no_cache", action="store_true", default=False) parser.add_argument("--compile", action="store_true", default=False) parser.add_argument( "--compile_mode", type=str, default="reduce-overhead", choices=["default", "reduce-overhead", "max-autotune"], ) parser.add_argument("--mucodec_device", type=str, default="auto") parser.add_argument("--mucodec_layer_num", type=int, default=7) parser.add_argument("--mucodec_duration", type=float, default=40.96) parser.add_argument("--mucodec_guidance_scale", type=float, default=1.5) parser.add_argument("--mucodec_num_steps", type=int, default=20) parser.add_argument("--mucodec_sample_rate", type=int, default=48000) parser.add_argument( "--output_dir", type=str, default="/root/new_batch_predictions", help="Root output dir. Each checkpoint gets its own subdirectory.", ) parser.add_argument( "--summary_json", type=str, default="/root/new_batch_predictions/summary.json", ) args = parser.parse_args() if not args.checkpoint_list and not args.checkpoint_dir: parser.error("one of --checkpoint_list or --checkpoint_dir is required") return args def parse_checkpoint_list(path: str) -> list[str]: checkpoints: list[str] = [] with open(path, "r", encoding="utf-8") as f: for raw_line in f: line = raw_line.strip() if not line or line.startswith("#"): continue checkpoints.append(line) if not checkpoints: raise ValueError(f"No checkpoints found in list: {path}") return checkpoints def scan_checkpoint_dir(path: str) -> list[str]: root = Path(path) if not root.is_dir(): raise NotADirectoryError(f"Checkpoint directory not found: {path}") checkpoint_dirs = [ item for item in root.iterdir() if item.is_dir() and item.name.startswith("checkpoint-") ] checkpoint_dirs = sorted( checkpoint_dirs, key=lambda p: int(p.name.split("-", 1)[1]) if p.name.split("-", 1)[1].isdigit() else p.name, ) final_dir = root / "final" if final_dir.is_dir(): checkpoint_dirs.append(final_dir) checkpoints = [str(path_obj) for path_obj in checkpoint_dirs] if not checkpoints: raise ValueError(f"No checkpoint-* directories found under: {path}") return checkpoints def get_dtype(name: str) -> torch.dtype: return { "float32": torch.float32, "float16": torch.float16, "bfloat16": torch.bfloat16, }[name] def get_split_size(dataset_path: str, split: str) -> int: dataset_obj = datasets.load_from_disk(dataset_path) if isinstance(dataset_obj, datasets.DatasetDict): if split not in dataset_obj: raise KeyError(f"Split not found: {split}") return len(dataset_obj[split]) return len(dataset_obj) def resolve_sample_indices( dataset_path: str, split: str, sample_indices: list[int] | None, max_samples: int, ) -> list[int]: if sample_indices: return list(sample_indices) split_size = get_split_size(dataset_path, split) if max_samples and max_samples > 0: split_size = min(split_size, max_samples) return list(range(split_size)) def sanitize_checkpoint_name(checkpoint_path: str) -> str: path = Path(checkpoint_path.rstrip("/")) if path.parent.name: return f"{path.parent.name}__{path.name}" return path.name def chunk_list(items: list[int], chunk_size: int) -> list[list[int]]: return [items[i : i + chunk_size] for i in range(0, len(items), chunk_size)] def main() -> None: args = parse_args() seed_everything(args.seed) if args.checkpoint_list: checkpoints = parse_checkpoint_list(args.checkpoint_list) else: checkpoints = scan_checkpoint_dir(args.checkpoint_dir) sample_indices = resolve_sample_indices( dataset_path=args.dataset_path, split=args.split, sample_indices=args.sample_indices, max_samples=args.max_samples, ) use_cache = args.use_cache and not args.no_cache device = resolve_device(args.device) dtype = get_dtype(args.dtype) if device.type == "cpu" and dtype != torch.float32: print(f"[WARN] dtype {dtype} on CPU may be unsupported; fallback to float32.") dtype = torch.float32 output_root = Path(args.output_dir) output_root.mkdir(parents=True, exist_ok=True) print(f"[INFO] checkpoints={len(checkpoints)}") print(f"[INFO] samples_per_checkpoint={len(sample_indices)}") print(f"[INFO] device={device}, dtype={dtype}, use_cache={use_cache}") mucodec_decoder = build_mucodec_decoder(args) summary: list[dict] = [] for checkpoint_path in checkpoints: ckpt_name = sanitize_checkpoint_name(checkpoint_path) ckpt_output_dir = output_root / ckpt_name json_dir = ckpt_output_dir / "json" wav_dir = ckpt_output_dir / "wav" print(f"\n[INFO] loading model from {checkpoint_path}") model = load_magel_checkpoint( checkpoint_path=checkpoint_path, device=device, dtype=dtype, attn_implementation=args.attn_implementation, ) model = maybe_compile_model( model, enabled=bool(args.compile), mode=str(args.compile_mode), ) num_audio_codebook = int(getattr(model.config, "magel_num_audio_token", 16384)) music_ds = load_music_dataset( dataset_path=args.dataset_path, split=args.split, tokenizer_path=args.tokenizer_path, num_audio_token=num_audio_codebook, use_fast=True, ) checkpoint_record = { "checkpoint_path": checkpoint_path, "checkpoint_name": ckpt_name, "status": "ok", "num_samples_requested": len(sample_indices), "results": [], } try: for batch_indices in chunk_list(sample_indices, max(1, int(args.infer_batch_size))): samples = [] for sample_idx in batch_indices: print( f"[INFO] checkpoint={ckpt_name} sample_idx={sample_idx} split={args.split}" ) samples.append( load_hf_template_sample_from_music_dataset( music_ds=music_ds, sample_idx=sample_idx, num_audio_codebook=num_audio_codebook, ) ) layout = TokenLayout( num_text_token=samples[0].num_text_token, num_audio_codebook=num_audio_codebook, ) if len(samples) == 1: batch_outputs = [ generate_segmentwise( model=model, sample=samples[0], layout=layout, device=device, use_cache=use_cache, temperature=float(args.temperature), top_k=int(args.top_k), top_p=float(args.top_p), greedy=bool(args.greedy), max_audio_tokens=max(0, int(args.max_audio_tokens)), ) ] else: try: batch_outputs = batch_generate_segmentwise( model=model, samples=samples, layout=layout, device=device, use_cache=use_cache, temperature=float(args.temperature), top_k=int(args.top_k), top_p=float(args.top_p), greedy=bool(args.greedy), max_audio_tokens=max(0, int(args.max_audio_tokens)), ) except Exception as exc: print( "[WARN] batch_generate_segmentwise failed; " f"falling back to single-sample decode. error={exc!r}" ) traceback.print_exc() batch_outputs = [ generate_segmentwise( model=model, sample=sample, layout=layout, device=device, use_cache=use_cache, temperature=float(args.temperature), top_k=int(args.top_k), top_p=float(args.top_p), greedy=bool(args.greedy), max_audio_tokens=max(0, int(args.max_audio_tokens)), ) for sample in samples ] for sample_idx, sample, batch_output in zip(batch_indices, samples, batch_outputs): generated_ids, sampled_count, sampled_chord_ids, sampled_segment_ids = batch_output prefix = f"{sample_idx:05d}_{sample.song_id}" # save_outputs expects these attributes on args. args.sample_idx = sample_idx args.json_output_dir = str(json_dir) args.wav_output_dir = str(wav_dir) save_outputs( output_dir=str(ckpt_output_dir), output_prefix=prefix, sample=sample, layout=layout, generated_ids=generated_ids, sampled_chord_ids=sampled_chord_ids, sampled_segment_ids=sampled_segment_ids, args=args, mucodec_decoder=mucodec_decoder, ) checkpoint_record["results"].append( { "sample_idx": sample_idx, "song_id": sample.song_id, "generated_audio_tokens": sampled_count, "wav_path": str(wav_dir / f"{prefix}.wav"), "json_path": str(json_dir / f"{prefix}.chord_segment.json"), } ) except Exception as exc: checkpoint_record["status"] = "error" checkpoint_record["error"] = str(exc) print(f"[ERROR] checkpoint {checkpoint_path}: {exc!r}") traceback.print_exc() summary.append(checkpoint_record) del model if device.type == "cuda": torch.cuda.empty_cache() summary_path = Path(args.summary_json) summary_path.parent.mkdir(parents=True, exist_ok=True) with open(summary_path, "w", encoding="utf-8") as f: json.dump(summary, f, ensure_ascii=False, indent=2) print(f"\nSaved summary to: {summary_path}") if __name__ == "__main__": main()