| from __future__ import annotations |
|
|
| import argparse |
| import csv |
| import json |
| import os |
| import random |
| import sys |
| from dataclasses import dataclass |
| from pathlib import Path |
| from typing import Any |
|
|
| import numpy as np |
| import torch |
| import torch.distributed as dist |
| import yaml |
|
|
|
|
# Repository root (two levels above this file); prepended to sys.path so the local
# ``bandtok`` package imported below resolves when this script is run directly.
REPO_ROOT = Path(__file__).resolve().parents[1]
if not any(entry == str(REPO_ROOT) for entry in sys.path):
    sys.path.insert(0, str(REPO_ROOT))
|
|
| from bandtok import BandTokPipeline |
| from bandtok.audio_utils import save_audio |
| from bandtok.model import _autocast, _dtype_from_config |
|
|
|
|
@dataclass
class InferenceItem:
    """One generation request: a text prompt plus its conditioning values and output name."""

    # Output path relative to the output directory; ``main`` appends ``.wav``.
    name: str
    # Text prompt, sent to the model as the ``"caption"`` conditioning entry.
    caption: str
    # Sent to the model as the ``"seconds_start"`` conditioning entry.
    seconds_start: float
    # Sent to the model as the ``"seconds_total"`` conditioning entry.
    seconds_total: float
    # Requested output length in seconds; sizes the token budget and the trim length.
    duration: float
|
|
|
|
def is_dist_ready() -> bool:
    """Return True only when torch.distributed is available and a process group exists."""
    if not dist.is_available():
        return False
    return dist.is_initialized()
|
|
|
|
def get_rank() -> int:
    """Global rank of this process; 0 when no process group is initialized."""
    if not is_dist_ready():
        return 0
    return dist.get_rank()
|
|
|
|
def get_world_size() -> int:
    """Number of participating processes; 1 when no process group is initialized."""
    if not is_dist_ready():
        return 1
    return dist.get_world_size()
|
|
|
|
def rank0_print(*args: Any, **kwargs: Any) -> None:
    """Forward to ``print`` on rank 0 only, so shared log lines appear once per job."""
    if get_rank() != 0:
        return
    print(*args, **kwargs)
|
|
|
|
def init_distributed() -> int:
    """Bind this process to its GPU and join the process group when launched multi-process.

    Reads ``LOCAL_RANK`` (default 0) and ``WORLD_SIZE`` (default 1) from the environment,
    as set by torchrun. A process group is created only when ``WORLD_SIZE > 1`` and none
    exists yet.

    Returns:
        The local rank of this process.
    """
    local_rank = int(os.environ.get("LOCAL_RANK", "0"))
    use_cuda = torch.cuda.is_available()
    if use_cuda:
        torch.cuda.set_device(local_rank)
    world_size = int(os.environ.get("WORLD_SIZE", "1"))
    if world_size > 1 and not is_dist_ready():
        # NCCL needs CUDA devices; fall back to Gloo so CPU-only multi-process runs
        # can still initialize instead of failing inside init_process_group.
        backend = "nccl" if use_cuda else "gloo"
        dist.init_process_group(backend=backend, init_method="env://")
    return local_rank
|
|
|
|
def cleanup_distributed(local_rank: int) -> None:
    """Synchronize all ranks, then tear down the process group. No-op when not distributed.

    Args:
        local_rank: CUDA device index of this process, passed to the NCCL barrier.
    """
    if not is_dist_ready():
        return
    try:
        # ``device_ids`` is an NCCL-specific barrier option (rejected by other backends
        # in some PyTorch versions), so only pass it when the group actually uses NCCL.
        if torch.cuda.is_available() and dist.get_backend() == dist.Backend.NCCL:
            dist.barrier(device_ids=[local_rank])
        else:
            dist.barrier()
    finally:
        # Always release the process group, even if the barrier fails.
        dist.destroy_process_group()
|
|
|
|
def chunk_list(items: list[InferenceItem], chunk_size: int):
    """Yield consecutive slices of ``items`` holding at most ``chunk_size`` elements each.

    The last slice carries the remainder and may be shorter; an empty list yields nothing.

    Raises:
        ValueError: on first iteration if ``chunk_size`` is not positive (a negative
            size would otherwise silently yield no chunks at all).
    """
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")
    for start in range(0, len(items), chunk_size):
        yield items[start : start + chunk_size]
|
|
|
|
def flatten_mapping(data: dict[str, Any], parent: str = "") -> dict[str, dict[str, Any]]:
    """Flatten a nested prompt mapping into ``{"group/sub/key": item_dict}``.

    A dict value containing any of ``caption``, ``prompt`` or ``text`` is a prompt item
    and is kept as-is. Any other dict is a group and is flattened recursively, with its
    key joined to the name by ``/``. Scalar values become ``{"caption": str(value)}``.

    ``None`` values (e.g. a YAML key with nothing after the colon) map to an empty item
    instead of ``{"caption": "None"}``. That way ``parse_items`` reports the missing caption
    rather than generating audio for the literal prompt "None".

    Args:
        data: Mapping to flatten.
        parent: Name prefix accumulated from enclosing groups.

    Returns:
        A flat mapping from slash-joined names to item dicts.
    """
    out: dict[str, dict[str, Any]] = {}
    for key, value in data.items():
        name = f"{parent}/{key}" if parent else str(key)
        if isinstance(value, dict):
            if any(field in value for field in ("caption", "prompt", "text")):
                out[name] = value
            else:
                out.update(flatten_mapping(value, name))
        elif value is None:
            out[name] = {}
        else:
            out[name] = {"caption": str(value)}
    return out
|
|
|
|
def safe_rel_name(name: str, index: int) -> str:
    """Sanitize ``name`` into a relative, forward-slash path that cannot escape upward.

    Backslashes become ``/``. Empty, ``.`` and ``..`` segments are dropped. If nothing
    survives, the zero-padded ``index`` is used instead.
    """
    normalized = str(name).replace("\\", "/")
    kept = [segment for segment in normalized.split("/") if segment not in ("", ".", "..")]
    return "/".join(kept) if kept else f"{index:06d}"
|
|
|
|
def _rows_from_root(data: Any, fmt: str) -> list[tuple[str, dict[str, Any]]]:
    """Turn a parsed YAML/JSON document root into ``(name, item)`` rows.

    A mapping root is flattened with ``flatten_mapping``. A list root yields one row per
    element: dict elements are named by their ``"id"`` (default: index), and other
    elements become ``{"caption": str(element)}`` named by index.

    Raises:
        ValueError: if the root is neither a mapping nor a list.
    """
    if isinstance(data, dict):
        return list(flatten_mapping(data).items())
    if isinstance(data, list):
        rows: list[tuple[str, dict[str, Any]]] = []
        for i, entry in enumerate(data):
            if isinstance(entry, dict):
                rows.append((str(entry.get("id", i)), entry))
            else:
                rows.append((str(i), {"caption": str(entry)}))
        return rows
    raise ValueError(f"Unsupported {fmt} root type: {type(data)!r}")


def load_raw_items(config_path: Path) -> list[tuple[str, dict[str, Any]]]:
    """Read a prompt config file into ``(name, item_dict)`` pairs.

    The format is chosen by the (case-insensitive) file suffix:
      * ``.yaml``/``.yml``/``.json``: a mapping (flattened) or a list of dicts/scalars.
      * ``.jsonl``: one JSON value per non-blank line; names default to the line index.
      * ``.csv``: one row per item, named by its ``id`` or ``name`` column, else row index.
      * anything else: plain text, one caption per non-blank line, named by the
        zero-padded line index.

    Files are decoded as ``utf-8-sig``, so a leading byte-order mark is dropped. With plain
    ``utf-8`` a BOM makes ``json`` raise and turns the first CSV header into ``"\\ufeffid"``.

    Raises:
        ValueError: if a YAML/JSON root is neither a mapping nor a list.
    """
    suffix = config_path.suffix.lower()
    encoding = "utf-8-sig"

    if suffix in (".yaml", ".yml"):
        with open(config_path, "r", encoding=encoding) as f:
            data = yaml.safe_load(f)
        return _rows_from_root(data, "YAML")

    if suffix == ".json":
        with open(config_path, "r", encoding=encoding) as f:
            data = json.load(f)
        return _rows_from_root(data, "JSON")

    if suffix == ".jsonl":
        rows: list[tuple[str, dict[str, Any]]] = []
        with open(config_path, "r", encoding=encoding) as f:
            for i, line in enumerate(f):
                line = line.strip()
                if not line:
                    continue
                item = json.loads(line)
                if not isinstance(item, dict):
                    item = {"caption": str(item)}
                rows.append((str(item.get("id", i)), item))
        return rows

    if suffix == ".csv":
        with open(config_path, "r", encoding=encoding, newline="") as f:
            reader = csv.DictReader(f)
            return [(str(row.get("id") or row.get("name") or i), dict(row)) for i, row in enumerate(reader)]

    # Fallback: plain text, one caption per line.
    with open(config_path, "r", encoding=encoding) as f:
        return [(f"{i:06d}", {"caption": line.strip()}) for i, line in enumerate(f) if line.strip()]
|
|
|
|
def _float_setting(override: float | None, data: dict[str, Any], key: str, default: float) -> float:
    """Return the CLI ``override`` if given, else ``data[key]`` as float, else ``default``.

    Missing keys, ``None`` and blank strings all fall back to ``default``. ``csv.DictReader``
    produces the latter two for empty or absent cells, and they would otherwise make
    ``float()`` raise.
    """
    if override is not None:
        return float(override)
    value = data.get(key)
    if value is None or (isinstance(value, str) and not value.strip()):
        return float(default)
    return float(value)


def parse_items(args: argparse.Namespace) -> list[InferenceItem]:
    """Load ``args.test_config`` and build the full list of generation requests.

    Per item, the caption comes from ``caption``/``prompt``/``text``. The output name comes
    from ``output``/``path``/the config name, sanitized by ``safe_rel_name``. Timing fields
    use CLI overrides when set, else per-item values, else defaults of 10.0 / 40.0 / 10.0 s
    for ``seconds_start`` / ``seconds_total`` / ``duration``.

    ``args.max_items > 0`` truncates the prompt list. Each prompt is then repeated
    ``args.n_sample_per_cond`` times, with ``_sampleNN`` name suffixes when more than one.

    Raises:
        ValueError: if an item has no caption, or if two requests would write the same
            output file (they would otherwise silently overwrite each other, possibly from
            different ranks at once).
    """
    raw_items = load_raw_items(Path(args.test_config).expanduser())
    items: list[InferenceItem] = []
    for index, (name, data) in enumerate(raw_items):
        caption = data.get("caption") or data.get("prompt") or data.get("text")
        if not caption:
            raise ValueError(f"Missing caption/prompt/text for item {name!r}")
        items.append(
            InferenceItem(
                name=safe_rel_name(data.get("output") or data.get("path") or name, index),
                caption=str(caption),
                seconds_start=_float_setting(args.second_start, data, "seconds_start", 10.0),
                seconds_total=_float_setting(args.second_total, data, "seconds_total", 40.0),
                duration=_float_setting(args.duration, data, "duration", 10.0),
            )
        )
    if args.max_items > 0:
        items = items[: args.max_items]

    n_samples = args.n_sample_per_cond
    if n_samples == 1:
        expanded = items
    else:
        expanded = [
            InferenceItem(
                name=f"{item.name}_sample{sample_idx:02d}",
                caption=item.caption,
                seconds_start=item.seconds_start,
                seconds_total=item.seconds_total,
                duration=item.duration,
            )
            for item in items
            for sample_idx in range(n_samples)
        ]

    seen: set[str] = set()
    for item in expanded:
        if item.name in seen:
            raise ValueError(f"Duplicate output name {item.name!r}: two prompts would write the same .wav file")
        seen.add(item.name)
    return expanded
|
|
|
|
def set_seed(seed: int, rank: int) -> None:
    """Seed Python, NumPy and torch (CPU and all CUDA devices) with ``seed + rank``.

    Offsetting by rank gives each process a distinct stream. A negative ``seed`` disables
    seeding and leaves all generators untouched.
    """
    if seed >= 0:
        effective = seed + rank
        for seeder in (random.seed, np.random.seed, torch.manual_seed):
            seeder(effective)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(effective)
|
|
|
|
def normalize_batch_audio(audio: torch.Tensor, batch_size: int) -> torch.Tensor:
    """Coerce generated audio into a 3-D ``(batch, channels, samples)``-style tensor.

    Rules:
      * 1-D ``(T,)`` becomes ``(1, 1, T)``.
      * 2-D with ``batch_size == 1``: the leading dim is treated as channels, ``(1, A, T)``.
      * 2-D with ``batch_size > 1``: the leading dim is treated as batch, ``(A, 1, T)``.
      * 3-D is returned unchanged.

    Raises:
        ValueError: for tensors with fewer than 1 or more than 3 dims, or when the leading
            dim is smaller than ``batch_size``. The caller indexes one entry per item, so
            without this check it would fail with an ``IndexError`` only after saving part
            of the batch.
    """
    if audio.ndim == 1:
        audio = audio.view(1, 1, -1)
    elif audio.ndim == 2:
        audio = audio.unsqueeze(0) if batch_size == 1 else audio.unsqueeze(1)
    elif audio.ndim != 3:
        raise ValueError(f"Expected generated audio with 1-3 dims, got {tuple(audio.shape)}")
    if audio.shape[0] < batch_size:
        raise ValueError(
            f"Generated audio has batch dimension {audio.shape[0]}, expected at least {batch_size} "
            f"(shape {tuple(audio.shape)})"
        )
    return audio
|
|
|
|
@torch.inference_mode()
def generate_batch(pipe: BandTokPipeline, items: list[InferenceItem], args: argparse.Namespace) -> torch.Tensor:
    """Generate audio for a batch of items with a single ``generate_audio`` call.

    The token budget is ``args.max_gen_len`` when set (truthy). Otherwise it is derived
    from the longest requested duration in the batch, so shorter items are generated
    long and trimmed by the caller. Sampling settings come from ``args``. The EOS token
    id is ``vocab_size - 2`` of the LM backbone config.

    Returns:
        The generated audio on CPU, normalized to 3 dims by ``normalize_batch_audio``.
    """
    generation_defaults = pipe.config.get("generation", {})
    pretransform = pipe.model.pretransform
    # tokens/sec = (sample_rate / downsampling_ratio) * num_quantizers
    token_rate = pipe.sample_rate / float(pretransform.downsampling_ratio) * float(pretransform.num_quantizers)
    longest = max(entry.duration for entry in items)
    if args.max_gen_len:
        max_gen_len = args.max_gen_len
    else:
        max_gen_len = max(1, round(longest * token_rate))
    # Top-level "precision" wins; otherwise fall back to generation.precision.
    dtype = _dtype_from_config(pipe.config.get("precision", generation_defaults.get("precision")))

    conditioning = []
    for entry in items:
        conditioning.append(
            {
                "caption": entry.caption,
                "seconds_start": entry.seconds_start,
                "seconds_total": entry.seconds_total,
            }
        )

    with _autocast(pipe.device, dtype):
        eos_token_id = pipe.model.lm.backbone.config.vocab_size - 2
        generated, _ = pipe.model.generate_audio(
            conditioning=conditioning,
            max_gen_len=max_gen_len,
            cfg_scale=args.cfg_scale,
            top_k=args.top_k,
            top_p=args.top_p,
            temperature=args.temperature,
            eos_token_id=eos_token_id,
            keep_sec_cond=args.keep_sec_cond,
        )
    return normalize_batch_audio(generated.detach().cpu(), len(items))
|
|
|
|
def build_argparser() -> argparse.ArgumentParser:
    """Build the command-line parser for multi-process batch inference.

    ``--repo-id`` is accepted as an alias of ``--repo_id``, so every option can be spelled
    with hyphens like the rest of the CLI. The parsed attribute remains ``args.repo_id``.
    """
    parser = argparse.ArgumentParser(description="Run multi-process BandTok batch inference with torchrun.")
    parser.add_argument("--repo_id", "--repo-id", dest="repo_id", default=str(REPO_ROOT), help="Hugging Face repo id or local repo directory.")
    parser.add_argument("--test-config", required=True, help="YAML/JSON/JSONL/CSV/TXT prompt config.")
    parser.add_argument("--output-dir", required=True, help="Directory for generated wav files.")
    parser.add_argument("--batch-size", type=int, default=4, help="Conditions per forward call per process.")
    parser.add_argument("--n-sample-per-cond", type=int, default=1)
    parser.add_argument("--duration", type=float, default=None, help="Override generated audio duration in seconds.")
    parser.add_argument("--second-start", type=float, default=None, help="Override seconds_start conditioning hyperparameter.")
    parser.add_argument("--second-total", "--second-end", dest="second_total", type=float, default=None, help="Override seconds_total conditioning hyperparameter.")
    parser.add_argument("--cfg-scale", type=float, default=2.0)
    parser.add_argument("--temperature", type=float, default=0.8)
    parser.add_argument("--top-k", type=int, default=50)
    parser.add_argument("--top-p", type=float, default=0.6)
    parser.add_argument("--max-gen-len", type=int, default=None)
    parser.add_argument("--max-items", type=int, default=0, help="0 means all prompts.")
    parser.add_argument("--seed", type=int, default=0, help="Base random seed. Use -1 to disable seeding.")
    parser.add_argument("--keep-sec-cond", action="store_true")
    parser.add_argument("--skip-existing", action="store_true")
    parser.add_argument("--no-clip-duration", action="store_true", help="Do not trim each wav to its requested duration.")
    return parser
|
|
|
|
def main() -> None:
    """Shard the prompt list across ranks, generate audio in batches and save wav files.

    Each rank takes every ``world_size``-th item (round-robin), generates in chunks of
    ``--batch-size``, optionally trims each clip to its requested duration, and writes
    ``<output-dir>/<item.name>.wav``.

    Raises:
        ValueError: if ``--batch-size`` or ``--n-sample-per-cond`` is not positive.
    """
    args = build_argparser().parse_args()
    if args.batch_size <= 0:
        raise ValueError("--batch-size must be positive")
    if args.n_sample_per_cond <= 0:
        # Otherwise every prompt expands to zero samples and the run is a silent no-op.
        raise ValueError("--n-sample-per-cond must be positive")
    local_rank = init_distributed()
    rank = get_rank()
    world_size = get_world_size()
    set_seed(args.seed, rank)

    device = f"cuda:{local_rank}" if torch.cuda.is_available() else "cpu"
    repo_path = Path(args.repo_id).expanduser()
    # An existing local directory is resolved to an absolute path; anything else is
    # passed through as a (Hugging Face) repo id.
    repo_id = str(repo_path.resolve()) if repo_path.exists() else args.repo_id
    items = parse_items(args)
    # Round-robin sharding: rank r handles items r, r + world_size, r + 2 * world_size, ...
    rank_items = items[rank::world_size]
    output_dir = Path(args.output_dir).expanduser()
    output_dir.mkdir(parents=True, exist_ok=True)
    if args.skip_existing:
        # Drop finished items before batching so the remaining work fills whole batches
        # instead of running many partially-empty forward passes.
        rank_items = [item for item in rank_items if not (output_dir / f"{item.name}.wav").is_file()]

    rank0_print("=== BandTok MP Inference ===")
    rank0_print(f"repo_id={repo_id}")
    rank0_print(f"test_config={args.test_config}")
    rank0_print(f"output_dir={output_dir}")
    rank0_print(f"world_size={world_size}, batch_size(per-rank)={args.batch_size}")
    rank0_print(f"total_items={len(items)}")

    pipe = BandTokPipeline.from_pretrained(repo_id, device=device)
    pipe.model.pretransform = pipe.model.pretransform.float()

    for batch_idx, batch in enumerate(chunk_list(rank_items, args.batch_size)):
        audio = generate_batch(pipe, batch, args)
        for item_idx, item in enumerate(batch):
            item_audio = audio[item_idx]
            if not args.no_clip_duration:
                item_audio = item_audio[..., : int(item.duration * pipe.sample_rate)]
            out_path = output_dir / f"{item.name}.wav"
            # Item names may contain "/" (nested config groups, "output"/"path" fields), but
            # only output_dir itself is created above; save_audio's directory handling is
            # not visible here, so create the parent explicitly.
            out_path.parent.mkdir(parents=True, exist_ok=True)
            save_audio(item_audio, out_path, pipe.sample_rate)
        print(f"[Rank-{rank}] batch {batch_idx} done, items={len(batch)}")

    cleanup_distributed(local_rank)
    if rank == 0:
        print("Generation finished.")
|
|
|
|
if __name__ == "__main__":
    # Script entry point; each process started by torchrun runs main() independently.
    main()
|
|