#!/usr/bin/env python3
# -*- coding: utf-8 -*-
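"""Run audio inference on validation samples for multiple checkpoints.

For each checkpoint the script loads the model, decodes the requested samples
(optionally batched), writes per-sample JSON/WAV outputs under a per-checkpoint
subdirectory of --output_dir, and records a run summary at --summary_json.

Example invocation (paths are illustrative):

    python batch_infer_checkpoints.py --checkpoint_dir runs/exp1 --max_samples 8
"""
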
import argparse
import json
import traceback
from pathlib import Path

import datasets
import torch

from inference_full import (
TokenLayout,
batch_generate_segmentwise,
build_mucodec_decoder,
generate_segmentwise,
load_hf_template_sample_from_music_dataset,
save_outputs,
)
from runtime_utils import (
load_magel_checkpoint,
load_music_dataset,
maybe_compile_model,
resolve_device,
seed_everything,
)


def parse_args() -> argparse.Namespace:
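    """Parse CLI arguments; --checkpoint_list or --checkpoint_dir must be given."""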
parser = argparse.ArgumentParser(
description="Run audio inference on validation samples for multiple checkpoints."
)
parser.add_argument(
"--checkpoint_list",
type=str,
default=None,
help="Text file with one checkpoint path per line.",
)
parser.add_argument(
"--checkpoint_dir",
type=str,
default=None,
help="Directory to scan for checkpoint-* subdirectories and optional final.",
)
parser.add_argument(
"--dataset_path",
type=str,
default="muse_mucodec_chord.ds",
)
parser.add_argument(
"--split",
type=str,
default="validation",
)
parser.add_argument(
"--tokenizer_path",
type=str,
default="checkpoints/Qwen3-0.6B",
)
parser.add_argument(
"--sample_indices",
type=int,
nargs="*",
default=None,
help="Specific sample indices to infer. Leave unset to run the full split.",
)
parser.add_argument(
"--max_samples",
type=int,
default=0,
help="Run only the first N samples from the split. Ignored if --sample_indices is set.",
)
parser.add_argument(
"--infer_batch_size",
type=int,
default=1,
help="Number of samples to decode together per step for the same checkpoint.",
)
parser.add_argument("--temperature", type=float, default=1.0)
parser.add_argument("--top_k", type=int, default=50)
parser.add_argument("--top_p", type=float, default=0.90)
parser.add_argument("--greedy", action="store_true", default=False)
parser.add_argument("--max_audio_tokens", type=int, default=0)
parser.add_argument("--fps", type=int, default=25)
parser.add_argument("--seed", type=int, default=1234)
parser.add_argument("--device", type=str, default="auto")
parser.add_argument(
"--dtype",
type=str,
default="bfloat16",
choices=["float32", "float16", "bfloat16"],
)
parser.add_argument(
"--attn_implementation",
type=str,
default="sdpa",
choices=["eager", "sdpa", "flash_attention_2"],
)
parser.add_argument("--use_cache", action="store_true", default=True)
parser.add_argument("--no_cache", action="store_true", default=False)
parser.add_argument("--compile", action="store_true", default=False)
parser.add_argument(
"--compile_mode",
type=str,
default="reduce-overhead",
choices=["default", "reduce-overhead", "max-autotune"],
)
parser.add_argument("--mucodec_device", type=str, default="auto")
parser.add_argument("--mucodec_layer_num", type=int, default=7)
parser.add_argument("--mucodec_duration", type=float, default=40.96)
parser.add_argument("--mucodec_guidance_scale", type=float, default=1.5)
parser.add_argument("--mucodec_num_steps", type=int, default=20)
parser.add_argument("--mucodec_sample_rate", type=int, default=48000)
parser.add_argument(
"--output_dir",
type=str,
default="/root/new_batch_predictions",
help="Root output dir. Each checkpoint gets its own subdirectory.",
)
parser.add_argument(
"--summary_json",
type=str,
default="/root/new_batch_predictions/summary.json",
)
args = parser.parse_args()
if not args.checkpoint_list and not args.checkpoint_dir:
parser.error("one of --checkpoint_list or --checkpoint_dir is required")
return args


def parse_checkpoint_list(path: str) -> list[str]:
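    """Read checkpoint paths from a text file, skipping blank lines and '#' comments."""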
checkpoints: list[str] = []
with open(path, "r", encoding="utf-8") as f:
for raw_line in f:
line = raw_line.strip()
if not line or line.startswith("#"):
continue
checkpoints.append(line)
if not checkpoints:
raise ValueError(f"No checkpoints found in list: {path}")
return checkpoints


def scan_checkpoint_dir(path: str) -> list[str]:
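    """Collect checkpoint-* subdirectories under path, plus an optional final/ directory."""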
root = Path(path)
if not root.is_dir():
raise NotADirectoryError(f"Checkpoint directory not found: {path}")
checkpoint_dirs = [
item
for item in root.iterdir()
if item.is_dir() and item.name.startswith("checkpoint-")
]
    def step_key(p: Path) -> tuple[int, object]:
        # Bucket numeric checkpoint steps (sorted ascending) before non-numeric
        # suffixes, so the sort key never compares an int against a str, which
        # would raise TypeError on a mixed directory listing.
        suffix = p.name.split("-", 1)[1]
        return (0, int(suffix)) if suffix.isdigit() else (1, suffix)

    checkpoint_dirs = sorted(checkpoint_dirs, key=step_key)
final_dir = root / "final"
if final_dir.is_dir():
checkpoint_dirs.append(final_dir)
checkpoints = [str(path_obj) for path_obj in checkpoint_dirs]
if not checkpoints:
raise ValueError(f"No checkpoint-* directories found under: {path}")
return checkpoints


def get_dtype(name: str) -> torch.dtype:
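    """Map a --dtype name onto the corresponding torch dtype."""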
return {
"float32": torch.float32,
"float16": torch.float16,
"bfloat16": torch.bfloat16,
}[name]


def get_split_size(dataset_path: str, split: str) -> int:
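    """Return the number of rows in `split` for the dataset saved at
    dataset_path; a bare Dataset (no splits) is measured directly."""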
dataset_obj = datasets.load_from_disk(dataset_path)
if isinstance(dataset_obj, datasets.DatasetDict):
if split not in dataset_obj:
raise KeyError(f"Split not found: {split}")
return len(dataset_obj[split])
return len(dataset_obj)


def resolve_sample_indices(
dataset_path: str,
split: str,
sample_indices: list[int] | None,
max_samples: int,
) -> list[int]:
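    """Return explicit sample_indices when given; otherwise the first
    max_samples rows of the split (all rows when max_samples <= 0)."""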
if sample_indices:
return list(sample_indices)
split_size = get_split_size(dataset_path, split)
    if max_samples > 0:
split_size = min(split_size, max_samples)
return list(range(split_size))


def sanitize_checkpoint_name(checkpoint_path: str) -> str:
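    """Build a collision-resistant output name like '<parent>__checkpoint-1000'."""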
path = Path(checkpoint_path.rstrip("/"))
if path.parent.name:
return f"{path.parent.name}__{path.name}"
return path.name


def chunk_list(items: list[int], chunk_size: int) -> list[list[int]]:
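    """Split items into consecutive chunks of at most chunk_size elements."""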
return [items[i : i + chunk_size] for i in range(0, len(items), chunk_size)]


def main() -> None:
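    """Resolve checkpoints and sample indices, decode every sample per
    checkpoint, and write a JSON summary of the whole run."""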
args = parse_args()
seed_everything(args.seed)
if args.checkpoint_list:
checkpoints = parse_checkpoint_list(args.checkpoint_list)
else:
checkpoints = scan_checkpoint_dir(args.checkpoint_dir)
sample_indices = resolve_sample_indices(
dataset_path=args.dataset_path,
split=args.split,
sample_indices=args.sample_indices,
max_samples=args.max_samples,
)
use_cache = args.use_cache and not args.no_cache
device = resolve_device(args.device)
dtype = get_dtype(args.dtype)
if device.type == "cpu" and dtype != torch.float32:
print(f"[WARN] dtype {dtype} on CPU may be unsupported; fallback to float32.")
dtype = torch.float32
output_root = Path(args.output_dir)
output_root.mkdir(parents=True, exist_ok=True)
print(f"[INFO] checkpoints={len(checkpoints)}")
print(f"[INFO] samples_per_checkpoint={len(sample_indices)}")
print(f"[INFO] device={device}, dtype={dtype}, use_cache={use_cache}")
mucodec_decoder = build_mucodec_decoder(args)
summary: list[dict] = []
for checkpoint_path in checkpoints:
ckpt_name = sanitize_checkpoint_name(checkpoint_path)
ckpt_output_dir = output_root / ckpt_name
json_dir = ckpt_output_dir / "json"
wav_dir = ckpt_output_dir / "wav"
print(f"\n[INFO] loading model from {checkpoint_path}")
model = load_magel_checkpoint(
checkpoint_path=checkpoint_path,
device=device,
dtype=dtype,
attn_implementation=args.attn_implementation,
)
model = maybe_compile_model(
model,
enabled=bool(args.compile),
mode=str(args.compile_mode),
)
num_audio_codebook = int(getattr(model.config, "magel_num_audio_token", 16384))
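        # Rebuild the dataset view per checkpoint so num_audio_token matches
        # the audio vocabulary size read from this model's config.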
music_ds = load_music_dataset(
dataset_path=args.dataset_path,
split=args.split,
tokenizer_path=args.tokenizer_path,
num_audio_token=num_audio_codebook,
use_fast=True,
)
checkpoint_record = {
"checkpoint_path": checkpoint_path,
"checkpoint_name": ckpt_name,
"status": "ok",
"num_samples_requested": len(sample_indices),
"results": [],
}
try:
for batch_indices in chunk_list(sample_indices, max(1, int(args.infer_batch_size))):
samples = []
for sample_idx in batch_indices:
print(
f"[INFO] checkpoint={ckpt_name} sample_idx={sample_idx} split={args.split}"
)
samples.append(
load_hf_template_sample_from_music_dataset(
music_ds=music_ds,
sample_idx=sample_idx,
num_audio_codebook=num_audio_codebook,
)
)
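                # Derive the token layout from the first sample; this assumes
                # every sample in the batch shares the same text vocabulary size.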
layout = TokenLayout(
num_text_token=samples[0].num_text_token,
num_audio_codebook=num_audio_codebook,
)
if len(samples) == 1:
batch_outputs = [
generate_segmentwise(
model=model,
sample=samples[0],
layout=layout,
device=device,
use_cache=use_cache,
temperature=float(args.temperature),
top_k=int(args.top_k),
top_p=float(args.top_p),
greedy=bool(args.greedy),
max_audio_tokens=max(0, int(args.max_audio_tokens)),
)
]
else:
try:
batch_outputs = batch_generate_segmentwise(
model=model,
samples=samples,
layout=layout,
device=device,
use_cache=use_cache,
temperature=float(args.temperature),
top_k=int(args.top_k),
top_p=float(args.top_p),
greedy=bool(args.greedy),
max_audio_tokens=max(0, int(args.max_audio_tokens)),
)
except Exception as exc:
print(
"[WARN] batch_generate_segmentwise failed; "
f"falling back to single-sample decode. error={exc!r}"
)
traceback.print_exc()
batch_outputs = [
generate_segmentwise(
model=model,
sample=sample,
layout=layout,
device=device,
use_cache=use_cache,
temperature=float(args.temperature),
top_k=int(args.top_k),
top_p=float(args.top_p),
greedy=bool(args.greedy),
max_audio_tokens=max(0, int(args.max_audio_tokens)),
)
for sample in samples
]
for sample_idx, sample, batch_output in zip(batch_indices, samples, batch_outputs):
generated_ids, sampled_count, sampled_chord_ids, sampled_segment_ids = batch_output
prefix = f"{sample_idx:05d}_{sample.song_id}"
# save_outputs expects these attributes on args.
args.sample_idx = sample_idx
args.json_output_dir = str(json_dir)
args.wav_output_dir = str(wav_dir)
save_outputs(
output_dir=str(ckpt_output_dir),
output_prefix=prefix,
sample=sample,
layout=layout,
generated_ids=generated_ids,
sampled_chord_ids=sampled_chord_ids,
sampled_segment_ids=sampled_segment_ids,
args=args,
mucodec_decoder=mucodec_decoder,
)
checkpoint_record["results"].append(
{
"sample_idx": sample_idx,
"song_id": sample.song_id,
"generated_audio_tokens": sampled_count,
"wav_path": str(wav_dir / f"{prefix}.wav"),
"json_path": str(json_dir / f"{prefix}.chord_segment.json"),
}
)
except Exception as exc:
checkpoint_record["status"] = "error"
checkpoint_record["error"] = str(exc)
print(f"[ERROR] checkpoint {checkpoint_path}: {exc!r}")
traceback.print_exc()
summary.append(checkpoint_record)
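        # Release the model (and any cached CUDA memory) before the next checkpoint.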
del model
if device.type == "cuda":
torch.cuda.empty_cache()
summary_path = Path(args.summary_json)
summary_path.parent.mkdir(parents=True, exist_ok=True)
with open(summary_path, "w", encoding="utf-8") as f:
json.dump(summary, f, ensure_ascii=False, indent=2)
print(f"\nSaved summary to: {summary_path}")


if __name__ == "__main__":
main()