| import numpy as np
|
| import librosa
|
| import torch
|
| import matplotlib.pyplot as plt
|
|
|
| from metrics.event_based_metrics import event_metrics
|
| from src.audio_preprocessing import readLabels, object_padding, fbank_features_extraction
|
| from tslearn.clustering import TimeSeriesKMeans, KShape
|
| from tslearn.metrics import dtw
|
| from sklearn.cluster import KMeans, AffinityPropagation, AgglomerativeClustering, MeanShift, estimate_bandwidth, DBSCAN, OPTICS, Birch
|
| from sklearn.metrics import accuracy_score, f1_score
|
|
|
| import os
|
| os.environ["OMP_NUM_THREADS"] = '3'
|
|
|
def standardize_array(array):
    """Z-score normalize each column of `array`.

    Subtracts the per-column mean and divides by the per-column standard
    deviation. Columns with zero standard deviation are divided by 1
    instead, so constant columns come out as all zeros.
    """
    col_mean = np.mean(array, axis=0)
    col_std = np.std(array, axis=0)
    # Guard against division by zero on constant columns.
    safe_std = np.where(col_std == 0, 1, col_std)
    return (array - col_mean) / safe_std
|
|
|
|
|
def kmeans_clustering(audio_data, n_clusters=2):
    """Cluster `audio_data` with k-means and return the per-sample labels."""
    # fit_predict on the training data is equivalent to fit() followed by
    # predict() on the same data.
    model = KMeans(n_clusters=n_clusters, random_state=42)
    return model.fit_predict(audio_data)
|
|
|
def dtw_kmedoids_clustering(audio_data, n_clusters=2):
    """Cluster time series with DTW-metric k-means and return labels.

    NOTE(review): despite the name, this uses TimeSeriesKMeans (centroids),
    not k-medoids — confirm the intended algorithm.
    """
    model = TimeSeriesKMeans(
        n_clusters=n_clusters,
        metric="dtw",
        max_iter=10,
        random_state=42,
    )
    return model.fit_predict(audio_data)
|
|
|
def kshape_clustering(audio_data, n_clusters=2):
    """Cluster time series with the k-Shape algorithm and return labels."""
    model = KShape(n_clusters=n_clusters, max_iter=10, random_state=42)
    return model.fit_predict(audio_data)
|
|
|
def affinity_propagation_clustering(audio_data):
    """Cluster `audio_data` with Affinity Propagation; the number of
    clusters is determined by the algorithm itself."""
    model = AffinityPropagation(random_state=42)
    return model.fit_predict(audio_data)
|
|
|
def agglomerative_clustering(audio_data, n_clusters=2):
    """Average-linkage agglomerative clustering on a precomputed DTW
    distance matrix.

    The original built the full n x n matrix with n^2 DTW calls; DTW is
    symmetric and dtw(x, x) == 0, so each unordered pair is now computed
    once and mirrored, halving the dominant cost.
    """
    n = len(audio_data)
    distances = np.zeros((n, n))
    for i in range(n):
        for j in range(i + 1, n):
            d = dtw(audio_data[i], audio_data[j])
            distances[i, j] = d
            distances[j, i] = d
    agg = AgglomerativeClustering(
        n_clusters=n_clusters, metric='precomputed', linkage='average'
    )
    return agg.fit_predict(distances)
|
|
|
def mean_shift_clustering(audio_data):
    """Cluster `audio_data` with Mean Shift using an estimated bandwidth.

    Samples outside any kernel are labelled -1 (cluster_all=False).
    """
    estimated_bw = estimate_bandwidth(audio_data, quantile=0.2)
    model = MeanShift(bandwidth=estimated_bw, bin_seeding=True, cluster_all=False)
    return model.fit_predict(audio_data)
|
|
|
def bisecting_kmeans_clustering(audio_data, n_clusters=2):
    """Bisecting k-means: repeatedly split the largest cluster in two with
    DTW-metric k-means until `n_clusters` clusters exist.

    Parameters
    ----------
    audio_data : array-like, shape (n_samples, ...)
        Samples to cluster.
    n_clusters : int
        Number of final clusters.

    Returns
    -------
    np.ndarray of shape (n_samples,) with integer cluster ids.

    Bug fixed: the original mapped samples back to clusters with
    `x in cluster`, which for numpy arrays means `(cluster == x).any()` —
    a broadcast element-membership test, not row membership — so labels
    could be wrong (and duplicate rows were ambiguous). Clusters are now
    tracked as index arrays, so every sample is labelled exactly once.
    """
    data = np.asarray(audio_data)
    n_samples = len(data)
    # Each cluster is represented by the indices of its members.
    clusters = [np.arange(n_samples)]

    while len(clusters) < n_clusters:
        largest = max(range(len(clusters)), key=lambda i: len(clusters[i]))
        member_idx = clusters.pop(largest)

        km = TimeSeriesKMeans(n_clusters=2, metric="dtw", max_iter=10, random_state=42)
        sub_labels = km.fit_predict(data[member_idx])

        clusters.append(member_idx[sub_labels == 0])
        clusters.append(member_idx[sub_labels == 1])

    labels = np.full(n_samples, -1, dtype=int)
    for cluster_id, member_idx in enumerate(clusters):
        labels[member_idx] = cluster_id
    return labels
|
|
|
def dbscan_clustering(audio_data, eps=0.5, min_samples=5):
    """Density-based clustering with DBSCAN using DTW as the pairwise
    metric; noise samples get label -1."""
    model = DBSCAN(eps=eps, min_samples=min_samples, metric=dtw)
    return model.fit_predict(audio_data)
|
|
|
def optics_clustering(audio_data, min_samples=5):
    """Density-based clustering with OPTICS (xi extraction) using DTW as
    the pairwise metric; noise samples get label -1."""
    model = OPTICS(min_samples=min_samples, metric=dtw, cluster_method='xi')
    return model.fit_predict(audio_data)
|
|
|
def birch_clustering(audio_data, n_clusters=None, branching_factor=50, threshold=0.5):
    """Cluster `audio_data` with BIRCH; when n_clusters is None the CF-tree
    subclusters are returned as-is."""
    model = Birch(
        n_clusters=n_clusters,
        branching_factor=branching_factor,
        threshold=threshold,
    )
    return model.fit_predict(audio_data)
|
|
|
def clustering_predicting(model, annotation_file, audio_file, max_length, clustering_method="kmeans", k=2):
    """Run `model` on one audio file, cluster its latent features, and map
    the per-frame cluster labels back to a sample-level time series.

    Parameters
    ----------
    model : callable returning (x_recon, latent, u, loss) for a float tensor.
    annotation_file : path to the ground-truth annotation file.
    audio_file : path to the audio file.
    max_length : padded length (in samples) of signal and labels.
    clustering_method : one of "kmeans", "dtw", "kshape", "affinity",
        "agglomerative", "mean_shift", "bisecting", "DBSCAN", "OPTICS",
        "Birch".
    k : number of clusters for methods that take one.

    Returns
    -------
    (signal, fs, truth_labels, label_timeseries)

    Raises
    ------
    ValueError : if `clustering_method` is not recognized (previously an
        unknown method left the labels as None and crashed later with a
        cryptic AttributeError).
    """
    signal, fs = librosa.load(audio_file)
    signal = object_padding(signal, max_length)
    truth_labels = readLabels(path=annotation_file, sample_rate=fs)
    truth_labels = object_padding(truth_labels, max_length)

    test_audio = fbank_features_extraction([audio_file], max_length)

    test_input = torch.tensor(test_audio, dtype=torch.float32)
    x_recon, test_latent, test_u, loss = model(test_input)

    # NOTE(review): 703 appears to be the fixed number of feature frames
    # produced for max_length — TODO confirm against fbank_features_extraction.
    clustering_input = standardize_array(test_u.reshape((703, -1)).detach().numpy())

    # Dispatch table instead of an if/elif chain; unknown keys fail loudly.
    dispatch = {
        "kmeans": lambda x: kmeans_clustering(x, n_clusters=k),
        "dtw": lambda x: dtw_kmedoids_clustering(x, n_clusters=k),
        "kshape": lambda x: kshape_clustering(x, n_clusters=k),
        "affinity": affinity_propagation_clustering,
        "agglomerative": lambda x: agglomerative_clustering(x, n_clusters=k),
        "mean_shift": mean_shift_clustering,
        "bisecting": lambda x: bisecting_kmeans_clustering(x, n_clusters=k),
        "DBSCAN": dbscan_clustering,
        "OPTICS": optics_clustering,
        "Birch": lambda x: birch_clustering(x, n_clusters=k),
    }
    if clustering_method not in dispatch:
        raise ValueError(f"Unknown clustering_method: {clustering_method!r}")
    clustering_label = dispatch[clustering_method](clustering_input)

    # Expand per-frame labels to per-sample labels: 25 ms window, 10 ms hop.
    # abs() maps any noise label of -1 onto a positive cluster id.
    label_timeseries = np.zeros(max_length)
    begin = 0
    end = int(0.025 * fs)
    shift_step = int(0.01 * fs)
    for i in range(clustering_label.shape[0]):
        label_timeseries[begin:end] = abs(clustering_label[i])
        begin += shift_step
        end += shift_step

    return signal, fs, np.array(truth_labels), label_timeseries
|
|
|
def signal_visualization(signal, fs, truth_labels, label_timeseries):
    """Plot the waveform with ground-truth and cluster label curves
    overlaid, both scaled to just above the signal's peak."""
    time_axis = np.arange(len(signal)) * (1 / fs)
    scale = 1.1 * np.max(signal)
    # Truncate everything to the shorter of signal / labels.
    limit = np.min([signal.shape[0], len(truth_labels)])

    plt.figure(figsize=(24, 6))
    plt.plot(time_axis[:limit], signal[:limit])
    plt.plot(time_axis[:limit], truth_labels[:limit] * scale)
    plt.plot(time_axis[:limit], label_timeseries[:limit] * scale)

    plt.title("Ground truth labels")
    plt.legend(['Signal', 'Cry', 'Clusters'])
    plt.show()
|
|
|
def cluster_visualization(signal, fs, truth_labels, label_timeseries):
    """Plot the waveform with ground-truth regions shaded above the time
    axis (orange = cry, gray = non-cry) and predicted cluster regions
    shaded below it (one tab10 color per cluster id).

    Assumes truth_labels uses 1 for cry and 0 for non-cry samples.
    NOTE(review): if there are no cry (or no non-cry) samples, the run
    boundary math below indexes empty arrays — TODO confirm inputs always
    contain both classes.
    """
    Ns = len(signal)
    Ts = 1 / fs
    t = np.arange(Ns) * Ts
    # Shaded bands extend to just above the signal peak.
    norm_coef = 1.1 * np.max(signal)
    # Truncate to the shorter of signal / labels.
    edge_ind = np.min([signal.shape[0], len(truth_labels)])

    plt.figure(figsize=(24, 6))
    line_signal, = plt.plot(t[:edge_ind], signal[:edge_ind])

    # Sample indices belonging to each ground-truth class.
    cry_indices = np.where(truth_labels == 1)[0]
    non_cry_indices = np.where(truth_labels == 0)[0]

    # Contiguous-run boundaries within cry_indices: a new run starts right
    # after each gap (diff != 1); a run ends just before the next gap.
    start_cry = np.insert(np.where(np.diff(cry_indices) != 1)[0] + 1, 0, 0)
    end_cry = np.append(np.where(np.diff(cry_indices) != 1)[0], len(cry_indices) - 1)

    # Shade each cry run above the axis; label only the first so the
    # legend gets a single 'Cry' entry.
    for start, end in zip(start_cry, end_cry):
        plt.fill_between(
            t[cry_indices[start:end+1]],
            0,
            norm_coef,
            color='orange',
            alpha=0.5,
            label='Cry' if start == start_cry[0] else None
        )
    # Proxy artists for the legend (fill_between patches don't reuse well).
    legend_handles = []
    legend_handles.append(plt.Rectangle((0, 0), 1, 1, color='orange', alpha=0.5))

    # Same run-boundary computation for the non-cry class.
    start_non_cry = np.insert(np.where(np.diff(non_cry_indices) != 1)[0] + 1, 0, 0)
    end_non_cry = np.append(np.where(np.diff(non_cry_indices) != 1)[0], len(non_cry_indices) - 1)

    for start, end in zip(start_non_cry, end_non_cry):
        plt.fill_between(
            t[non_cry_indices[start:end+1]],
            0,
            norm_coef,
            color='gray',
            alpha=0.5,
            label='Non-cry' if start == start_non_cry[0] else None
        )
    legend_handles.append(plt.Rectangle((0, 0), 1, 1, color='gray', alpha=0.5))

    # Predicted clusters: one color per distinct label value.
    unique_labels = np.unique(label_timeseries)
    cmap = plt.get_cmap('tab10')

    for i, label in enumerate(unique_labels):
        label_indices = np.where(label_timeseries == label)[0]

        # Run boundaries for this cluster's samples, as above.
        start_indices = np.insert(np.where(np.diff(label_indices) != 1)[0] + 1, 0, 0)
        end_indices = np.append(np.where(np.diff(label_indices) != 1)[0], len(label_indices) - 1)

        # Cluster bands are drawn below the axis (0 down to -norm_coef).
        for start, end in zip(start_indices, end_indices):
            plt.fill_between(
                t[label_indices[start:end+1]],
                0,
                -norm_coef,
                color=cmap(i),
                alpha=0.5,
                label=f'Cluster {label}' if start == start_indices[0] else None
            )
        legend_handles.append(plt.Rectangle((0, 0), 1, 1, color=cmap(i), alpha=0.5))

    plt.title("Audio Clustering")
    plt.legend(
        [line_signal] + legend_handles,
        ['Signal'] + ['Cry', 'Non-Cry'] + [f'Cluster {label}' for label in unique_labels]
    )
    plt.show()
|
|
|
def clustering_evaluatation(model, max_length, audio_files, annotation_files, domain_index=None, clustering_method="kmeans", k=2):
    """Evaluate clustering predictions against annotations over a file set.

    Parameters
    ----------
    model : the feature-extraction model passed to clustering_predicting.
    max_length : padded length (in samples) per file.
    audio_files, annotation_files : parallel lists of file paths.
    domain_index : iterable of indices to evaluate; defaults to all files.
    clustering_method, k : forwarded to clustering_predicting.

    Returns
    -------
    (acc_list, framef_list, eventf_list, iou_list, switch_list) — per-file
    accuracy, frame-level F1, event-level F1, IoU, and whether the cluster
    orientation was flipped for the accuracy score.

    Bug fixed: f1_score was called as f1_score(y_pred, y_true); sklearn's
    signature is f1_score(y_true, y_pred) and F1 is not symmetric in its
    arguments, so the ground truth now comes first.
    """
    acc_list, framef_list, eventf_list, iou_list = [], [], [], []
    switch_list = []
    if domain_index is None:
        domain_index = range(len(audio_files))
    for i in domain_index:
        annotation_file = annotation_files[i]
        audio_file = audio_files[i]
        clustering_switch = False
        _, _, truth_labels, label_timeseries = clustering_predicting(
            model, annotation_file, audio_file, max_length, clustering_method, k
        )
        temp_accuracy = accuracy_score(truth_labels, label_timeseries)
        # Cluster ids are arbitrary: binarize the prediction in both
        # orientations and keep the better frame-level F1.
        framef = max(
            f1_score(truth_labels, 1 - label_timeseries > 0),
            f1_score(truth_labels, label_timeseries > 0),
        )
        if temp_accuracy < 0.5:
            # Below-chance accuracy means the cluster ids are inverted
            # relative to the ground truth; report the flipped score.
            clustering_accuracy = 1 - temp_accuracy
            clustering_switch = True
        else:
            clustering_accuracy = temp_accuracy
        acc_list.append(clustering_accuracy)
        switch_list.append(clustering_switch)
        framef_list.append(framef)
        eventf, iou, _, _, _ = event_metrics(truth_labels, label_timeseries, tolerance=2000, overlap_threshold=0.75)
        eventf_list.append(eventf)
        iou_list.append(iou)

    return acc_list, framef_list, eventf_list, iou_list, switch_list