| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import argparse |
| import multiprocessing |
| import shutil |
| from collections import OrderedDict |
| from pathlib import Path |
| from pprint import pprint |
| from typing import Dict |
|
|
| import matplotlib.pyplot as plt |
| import numpy as np |
| import seaborn as sns |
| import sox |
| from scipy.stats import expon |
| from tqdm import tqdm |
|
|
| from nemo.collections.asr.parts.utils.vad_utils import ( |
| get_nonspeech_segments, |
| load_speech_overlap_segments_from_rttm, |
| plot_sample_from_rttm, |
| ) |
|
|
| """ |
| This script analyzes a multi-speaker speech dataset and generates statistics. |
| The input directory </path/to/rttm_and_wav_directory> is required to contain the following files: |
| - rttm files (*.rttm) |
| - wav files (*.wav) |
| |
| Usage: |
| python <NEMO_ROOT>/scripts/speaker_tasks/multispeaker_data_analysis.py \ |
| </path/to/rttm_and_wav_directory> \ |
| --session_dur 20 \ |
| --silence_mean 0.2 \ |
| --silence_var 100 \ |
| --overlap_mean 0.15 \ |
| --overlap_var 50 \ |
| --num_workers 8 \ |
| --num_samples 10 \ |
| --output_dir <path/to/output_directory> |
| """ |
|
|
|
|
def process_sample(sess_dict: Dict) -> Dict:
    """
    Compute silence and overlap statistics for a single session.

    Args:
        sess_dict (dict): dictionary containing the following keys
            rttm_file (str or Path): path to the rttm file; the matching wav file is
                looked up next to it (same name, ``.wav`` extension) when needed
            session_dur (float or None): duration of the session (specified by argument)
            precise (bool): whether to measure the precise duration of the session using sox

    Returns:
        results (dict): dictionary containing the following keys
            session_dur (float): duration of the session
            silence_len_list (list): list of silence durations of each silence occurrence
            silence_dur (float): total silence duration in a session
            silence_ratio (float): ratio of silence duration to session duration
                (0.0 when the session duration is zero)
            overlap_len_list (list): list of overlap durations of each overlap occurrence
            overlap_dur (float): total overlap duration
            overlap_ratio (float): ratio of overlap duration to speech (non-silence) duration
                (0.0 when the session contains no speech)

    Raises:
        ValueError: if the duration has to be measured but sox cannot determine it.
    """
    rttm_file = sess_dict["rttm_file"]
    session_dur = sess_dict["session_dur"]

    # Fall back to measuring the audio when no duration was given or precision is requested.
    if sess_dict["precise"] or session_dur is None:
        wav_file = Path(rttm_file).with_suffix(".wav")
        session_dur = sox.file_info.duration(str(wav_file))
        if session_dur is None:
            # Fail with a clear message instead of a TypeError in the arithmetic below.
            raise ValueError(f"Could not determine the duration of {wav_file}")

    speech_seg, overlap_seg = load_speech_overlap_segments_from_rttm(rttm_file)
    speech_dur = sum(seg[1] - seg[0] for seg in speech_seg)

    silence_seg = get_nonspeech_segments(speech_seg, session_dur)
    silence_len_list = [seg[1] - seg[0] for seg in silence_seg]
    # Clamp at zero: the annotated speech can exceed a user-specified session duration.
    silence_dur = max(0, session_dur - speech_dur)
    silence_ratio = silence_dur / session_dur if session_dur > 0 else 0.0

    overlap_len_list = [seg[1] - seg[0] for seg in overlap_seg]
    overlap_dur = sum(overlap_len_list)
    # A session without any speech has no overlap; avoid dividing by zero.
    overlap_ratio = overlap_dur / speech_dur if speech_dur > 0 else 0.0

    return {
        "session_dur": session_dur,
        "silence_len_list": silence_len_list,
        "silence_dur": silence_dur,
        "silence_ratio": silence_ratio,
        "overlap_len_list": overlap_len_list,
        "overlap_dur": overlap_dur,
        "overlap_ratio": overlap_ratio,
    }
|
|
|
|
def _fit_exponential_mean(lengths) -> float:
    """Return the mean of an exponential distribution fitted to ``lengths`` (loc fixed at 0), or NaN if empty."""
    if len(lengths) == 0:
        # expon.fit cannot be applied to an empty sample (e.g. a dataset without any overlap).
        return float("nan")
    _, scale = expon.fit(lengths, floc=0)
    return scale


def run_multispeaker_data_analysis(
    input_dir,
    session_dur=None,
    silence_mean=None,
    silence_var=None,
    overlap_mean=None,
    overlap_var=None,
    precise=False,
    save_path=None,
    num_workers=1,
) -> Dict:
    """
    Analyze the multispeaker data and plot the distribution of silence and overlap durations.

    Args:
        input_dir (str): path to the directory containing the rttm files
        session_dur (float): duration of the session (specified by argument)
        silence_mean (float): mean of the silence duration distribution
        silence_var (float): variance of the silence duration distribution
        overlap_mean (float): mean of the overlap duration distribution
        overlap_var (float): variance of the overlap duration distribution
        precise (bool): whether to measure the precise duration of the session using sox
        save_path (str): path to save the plots; nothing is saved when empty or None
        num_workers (int): number of worker processes; values <= 1 process the files
            sequentially in the current process

    Returns:
        stats (dict): dictionary containing the statistics of the analyzed data

    Raises:
        ValueError: if no rttm files are found in ``input_dir``.
    """
    rttm_list = list(Path(input_dir).glob("*.rttm"))

    print(f"Found {len(rttm_list)} files to be processed")
    if len(rttm_list) == 0:
        raise ValueError(f"No rttm files found in {input_dir}")

    queue = [{"rttm_file": rttm_file, "session_dur": session_dur, "precise": precise} for rttm_file in rttm_list]

    if num_workers <= 1:
        results = [process_sample(sess_dict) for sess_dict in tqdm(queue)]
    else:
        with multiprocessing.Pool(processes=num_workers) as p:
            results = list(tqdm(p.imap(process_sample, queue), total=len(queue), desc='Processing', leave=True,))

    # Aggregate the per-session results.
    silence_duration = 0.0
    total_duration = 0.0
    overlap_duration = 0.0

    silence_ratio_all = []
    overlap_ratio_all = []
    silence_length_all = []
    overlap_length_all = []

    for item in results:
        total_duration += item["session_dur"]
        silence_duration += item["silence_dur"]
        overlap_duration += item["overlap_dur"]

        silence_length_all += item["silence_len_list"]
        overlap_length_all += item["overlap_len_list"]

        silence_ratio_all.append(item["silence_ratio"])
        overlap_ratio_all.append(item["overlap_ratio"])

    # The means are duration-weighted over the whole dataset, whereas the variances are
    # taken over the per-session ratios. Guard the divisions against datasets that are
    # entirely empty or entirely silent.
    speech_duration = total_duration - silence_duration
    actual_silence_mean = silence_duration / total_duration if total_duration > 0 else 0.0
    actual_silence_var = np.var(silence_ratio_all)
    actual_overlap_mean = overlap_duration / speech_duration if speech_duration > 0 else 0.0
    actual_overlap_var = np.var(overlap_ratio_all)

    stats = OrderedDict()
    stats["total duration (hours)"] = f"{total_duration / 3600:.2f}"
    stats["number of sessions"] = len(rttm_list)
    stats["average session duration (seconds)"] = f"{total_duration / len(rttm_list):.2f}"
    stats["actual silence ratio mean/var"] = f"{actual_silence_mean:.4f}/{actual_silence_var:.4f}"
    stats["actual overlap ratio mean/var"] = f"{actual_overlap_mean:.4f}/{actual_overlap_var:.4f}"
    stats["expected silence ratio mean/var"] = f"{silence_mean}/{silence_var}"
    stats["expected overlap ratio mean/var"] = f"{overlap_mean}/{overlap_var}"
    stats["save_path"] = save_path

    print("-----------------------------------------------")
    print("                    Results                    ")
    print("-----------------------------------------------")
    for k, v in stats.items():
        print(k, ": ", v)
    print("-----------------------------------------------")

    fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(14, 14))
    fig.suptitle(
        f"Average session={total_duration/len(rttm_list):.2f} seconds, num sessions={len(rttm_list)}, total={total_duration/3600:.2f} hours"
    )

    # Top-left: per-session silence ratio.
    sns.histplot(silence_ratio_all, ax=ax1)
    ax1.set_xlabel("Silence ratio in a session")
    ax1.set_title(
        f"Target silence mean={silence_mean}, var={silence_var}. \nActual silence ratio={actual_silence_mean:.4f}, var={actual_silence_var:.4f}"
    )

    # Top-right: length of each individual silence segment.
    scale = _fit_exponential_mean(silence_length_all)
    sns.histplot(silence_length_all, ax=ax2)
    ax2.set_xlabel("Per-silence length in seconds")
    ax2.set_title(f"Per-silence length histogram, \nfitted exponential distribution with mean={scale:.4f}")

    # Bottom-left: per-session overlap ratio.
    sns.histplot(overlap_ratio_all, ax=ax3)
    ax3.set_title(
        f"Target overlap mean={overlap_mean}, var={overlap_var}. \nActual ratio={actual_overlap_mean:.4f}, var={actual_overlap_var:.4f}"
    )
    ax3.set_xlabel("Overlap ratio in a session")

    # Bottom-right: length of each individual overlap segment.
    scale2 = _fit_exponential_mean(overlap_length_all)
    sns.histplot(overlap_length_all, ax=ax4)
    ax4.set_title(f"Per overlap length histogram, \nfitted exponential distribution with mean={scale2:.4f}")
    ax4.set_xlabel("Duration in seconds")

    if save_path:
        fig.savefig(save_path)
        print(f"Figure saved at: {save_path}")

    return stats
|
|
|
|
def visualize_multispeaker_data(input_dir: str, output_dir: str, num_samples: int = 10) -> None:
    """
    Visualize a set of randomly sampled data in the input directory

    Each selected ``<name>.rttm`` is plotted together with the ``<name>.wav`` file next to it,
    and the figure is written to ``<output_dir>/<name>.png``.

    Args:
        input_dir (str): Path to the input directory
        output_dir (str): Path to the output directory (created if it does not exist)
        num_samples (int): Maximum number of samples to visualize; fewer are plotted
            when the input directory contains fewer rttm files
    """
    rttm_list = list(Path(input_dir).glob("*.rttm"))
    idx_list = np.random.permutation(len(rttm_list))[:num_samples]

    # Make the function usable on its own, without relying on the caller to create the directory.
    out_dir = Path(output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    # Report the number actually plotted, which can be smaller than num_samples.
    print(f"Visualizing {len(idx_list)} random samples")
    for idx in idx_list:
        rttm_file = rttm_list[idx]
        audio_file = rttm_file.with_suffix(".wav")
        output_file = out_dir / f"{rttm_file.stem}.png"
        plot_sample_from_rttm(audio_file=audio_file, rttm_file=rttm_file, save_path=str(output_file), show=False)
    print(f"Sample plots saved at: {output_dir}")
|
|
|
|
# Command-line entry point: compute dataset statistics, then plot a few random samples.
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Required positional argument (a default would never be used by argparse).
    parser.add_argument("input_dir", help="Input directory")
    parser.add_argument("-sd", "--session_dur", default=None, type=float, help="Duration per session in seconds")
    parser.add_argument("-sm", "--silence_mean", default=None, type=float, help="Expected silence ratio mean")
    parser.add_argument("-sv", "--silence_var", default=None, type=float, help="Expected silence ratio variance")
    parser.add_argument("-om", "--overlap_mean", default=None, type=float, help="Expected overlap ratio mean")
    parser.add_argument("-ov", "--overlap_var", default=None, type=float, help="Expected overlap ratio variance")
    parser.add_argument("-w", "--num_workers", default=1, type=int, help="Number of CPU workers to use")
    parser.add_argument("-s", "--num_samples", default=10, type=int, help="Number of random samples to plot")
    parser.add_argument("-o", "--output_dir", default="analysis/", type=str, help="Directory for saving output figure")
    parser.add_argument(
        "--precise", action="store_true", help="Set to get precise duration, with significant time cost"
    )
    args = parser.parse_args()

    print("Running with params:")
    pprint(vars(args))

    output_dir = Path(args.output_dir)

    # An existing output directory is wiped below, so refuse any output location whose
    # removal would also delete the input data (the input directory itself or a parent of it).
    resolved_input = Path(args.input_dir).resolve()
    resolved_output = output_dir.resolve()
    if resolved_output == resolved_input or resolved_output in resolved_input.parents:
        parser.error(
            f"--output_dir ({args.output_dir}) must not be the input directory or one of its parents, "
            "because an existing output directory is removed before results are written"
        )

    if output_dir.exists():
        print(f"Removing existing output directory: {args.output_dir}")
        shutil.rmtree(str(output_dir))
    output_dir.mkdir(parents=True)

    run_multispeaker_data_analysis(
        input_dir=args.input_dir,
        session_dur=args.session_dur,
        silence_mean=args.silence_mean,
        silence_var=args.silence_var,
        overlap_mean=args.overlap_mean,
        overlap_var=args.overlap_var,
        precise=args.precise,
        save_path=str(Path(args.output_dir, "statistics.png")),
        num_workers=args.num_workers,
    )

    visualize_multispeaker_data(input_dir=args.input_dir, output_dir=args.output_dir, num_samples=args.num_samples)

    print("The multispeaker data analysis has been completed.")
    print(f"Please check the output directory: \n{args.output_dir}")
|
|