moudook committed on
Commit
7033335
·
verified ·
1 Parent(s): ec51c17

Upload train_speaker_id.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. train_speaker_id.py +391 -0
train_speaker_id.py ADDED
@@ -0,0 +1,391 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Speaker Identification using PCA and Classical ML Models
3
+ ========================================================
4
+ Analyzes ECAPA embeddings using PCA and evaluates:
5
+ - Logistic Regression
6
+ - SVM (Linear)
7
+ - SVM (RBF/Gaussian)
8
+ - k-Nearest Neighbors (k-NN)
9
+
10
+ Deliverables:
11
+ - PCA visualization plots (2D)
12
+ - Accuracy comparison table (all models x PCA dims)
13
+ - Precision, Recall, F1, Confusion Matrices
14
+ - Trained ML models (saved with joblib)
15
+ """
16
+
17
+ import os
18
+ import time
19
+ import json
20
+ from pathlib import Path
21
+
22
+ import matplotlib
23
+ matplotlib.use("Agg") # Non-interactive backend for server
24
+ import matplotlib.pyplot as plt
25
+ import numpy as np
26
+ import pandas as pd
27
+ import seaborn as sns
28
+ from joblib import dump
29
+ from sklearn.decomposition import PCA
30
+ from sklearn.linear_model import LogisticRegression
31
+ from sklearn.metrics import (
32
+ accuracy_score,
33
+ confusion_matrix,
34
+ f1_score,
35
+ precision_score,
36
+ recall_score,
37
+ )
38
+ from sklearn.model_selection import train_test_split
39
+ from sklearn.neighbors import KNeighborsClassifier
40
+ from sklearn.preprocessing import LabelEncoder, StandardScaler
41
+ from sklearn.svm import SVC
42
+ from tqdm.auto import tqdm
43
+
# ============================================================
# Configuration
# ============================================================
# Seed shared by the splits, PCA and the seeded classifiers.
RANDOM_STATE = 42
# Fraction of the full dataset held out as the final test split.
TEST_SIZE = 0.1
# Fraction of the remaining train+val data used for validation
# (0.1111 * 0.9 ≈ 0.10 of the full dataset).
VAL_SIZE = 0.1111
DATA_PATH = "voxceleb1_dev_ecapa_features.csv"

# Output layout: results/, results/models/, results/plots/.
OUTPUT_DIR = Path("results")
MODELS_DIR = OUTPUT_DIR / "models"
PLOTS_DIR = OUTPUT_DIR / "plots"

# Create the output tree up front, in the same order as before.
for _out_dir in (OUTPUT_DIR, MODELS_DIR, PLOTS_DIR):
    _out_dir.mkdir(parents=True, exist_ok=True)

_rule = "=" * 60
print(_rule)
print("Speaker Identification - PCA + ML Pipeline")
print(_rule)
# ============================================================
# 1. Load Data
# ============================================================
# Reads the CSV at DATA_PATH; every column whose name starts with "emb_"
# is treated as a feature, and "speaker_id" is the label column.
print("\n[1/8] Loading dataset...")
t0 = time.time()
df = pd.read_csv(DATA_PATH)
feature_cols = [col for col in df.columns if col.startswith("emb_")]
speaker_count = df["speaker_id"].nunique()
print(f" Dataset shape: {df.shape}")
print(f" Features: {len(feature_cols)}-dim ECAPA embeddings")
print(f" Unique speakers: {speaker_count}")
print(f" Load time: {time.time() - t0:.1f}s")
# ============================================================
# 2. Train / Validation / Test Split (80/10/10)
# ============================================================
# Two chained stratified splits: the test split is carved out first, then
# the validation split is taken from what remains.
print("\n[2/8] Splitting data 80/10/10 (speaker-stratified)...")
t0 = time.time()

# Stage 1: hold out TEST_SIZE of all rows, stratified on speaker_id.
df_trainval, df_test = train_test_split(
    df,
    test_size=TEST_SIZE,
    stratify=df["speaker_id"],
    random_state=RANDOM_STATE,
)

# Stage 2: split the remainder into train and validation, again stratified.
df_train, df_val = train_test_split(
    df_trainval,
    test_size=VAL_SIZE,
    stratify=df_trainval["speaker_id"],
    random_state=RANDOM_STATE,
)

total_rows = len(df)
for split_label, split_df in (("Train", df_train), ("Val", df_val), ("Test", df_test)):
    print(f" {split_label}: {len(split_df)} ({len(split_df)/total_rows*100:.1f}%)")
print(f" Split time: {time.time() - t0:.1f}s")

# Integer-encode speaker ids; the encoder is fitted on the full label column
# so every split shares one id -> index mapping.
le = LabelEncoder()
le.fit(df["speaker_id"])

# Feature matrices per split.
X_train = df_train[feature_cols].values
X_val = df_val[feature_cols].values
X_test = df_test[feature_cols].values

# Encoded targets per split.
y_train_enc = le.transform(df_train["speaker_id"])
y_val_enc = le.transform(df_val["speaker_id"])
y_test_enc = le.transform(df_test["speaker_id"])

num_classes = len(le.classes_)
print(f" Number of classes (speakers): {num_classes}")
# ============================================================
# 3. Standardize Features
# ============================================================
# The scaler statistics come from the training split only; the validation
# and test splits are transformed with those same statistics.
print("\n[3/8] Standardizing features...")
t0 = time.time()
scaler = StandardScaler().fit(X_train)
X_train_sc = scaler.transform(X_train)
X_val_sc = scaler.transform(X_val)
X_test_sc = scaler.transform(X_test)
print(f" Scaled train shape: {X_train_sc.shape}")
print(f" Scale time: {time.time() - t0:.1f}s")
# ============================================================
# 4. PCA Transformation (192, 100, 50, 2)
# ============================================================
print("\n[4/8] Applying PCA...")
t0 = time.time()

pca_100 = PCA(n_components=100, random_state=RANDOM_STATE)
pca_50 = PCA(n_components=50, random_state=RANDOM_STATE)
pca_2 = PCA(n_components=2, random_state=RANDOM_STATE)


def _fit_and_project(reducer):
    """Fit *reducer* on the scaled training split and project all three splits."""
    train_proj = reducer.fit_transform(X_train_sc)
    val_proj = reducer.transform(X_val_sc)
    test_proj = reducer.transform(X_test_sc)
    return train_proj, val_proj, test_proj


# Each reducer is fitted on the training data only, then applied to val/test.
X_train_pca100, X_val_pca100, X_test_pca100 = _fit_and_project(pca_100)
X_train_pca50, X_val_pca50, X_test_pca50 = _fit_and_project(pca_50)
X_train_pca2, X_val_pca2, X_test_pca2 = _fit_and_project(pca_2)

# Total share of training-set variance retained by each projection.
var_100 = pca_100.explained_variance_ratio_.sum()
var_50 = pca_50.explained_variance_ratio_.sum()
var_2 = pca_2.explained_variance_ratio_.sum()

print(f" PCA 100 explained variance: {var_100:.4f}")
print(f" PCA 50 explained variance: {var_50:.4f}")
print(f" PCA 2 explained variance: {var_2:.4f}")
print(f" PCA time: {time.time() - t0:.1f}s")
# ============================================================
# 5. PCA 2D Visualization
# ============================================================
# Scatter plot of the training split in the 2-component PCA space, with one
# colormap entry per encoded speaker label. Written to PLOTS_DIR as a PNG.
print("\n[5/8] Generating PCA 2D visualization...")
num_speakers = len(np.unique(y_train_enc))
# plt.cm.get_cmap() was deprecated in Matplotlib 3.7 and removed in 3.9.
# pyplot.get_cmap() takes the same (name, lut) arguments and is still
# available, so the colormap is resampled to num_speakers entries as before.
cmap = plt.get_cmap("nipy_spectral", num_speakers)

fig, ax = plt.subplots(figsize=(14, 10))
scatter = ax.scatter(
    X_train_pca2[:, 0], X_train_pca2[:, 1],
    c=y_train_enc, cmap=cmap, alpha=0.45, s=8,
    linewidths=0, rasterized=True, marker="o",
)
ax.set_title("2D PCA Projection of ECAPA Embeddings (Train Set)", fontsize=16)
# Axis labels report the variance share captured by each of the two components.
ax.set_xlabel(f"PC1 ({pca_2.explained_variance_ratio_[0] * 100:.2f}% variance)", fontsize=13)
ax.set_ylabel(f"PC2 ({pca_2.explained_variance_ratio_[1] * 100:.2f}% variance)", fontsize=13)
ax.grid(True, linestyle="--", alpha=0.3)
plt.tight_layout()
pca_plot_path = PLOTS_DIR / "pca_2d_visualization.png"
fig.savefig(pca_plot_path, dpi=150)
# Close the figure explicitly so it does not stay alive for the rest of the run.
plt.close(fig)
print(f" Saved: {pca_plot_path}")
# ============================================================
# 6. Train Models
# ============================================================
# Fits every classifier on every feature representation, scores the fitted
# model on the test split, records the metrics and saves the model to disk.
print("\n[6/8] Training models...")
models = {}

# Feature representations keyed by dimensionality label. Each value is the
# (train, val, test) triple; "192" is the standardized, un-reduced input.
feature_sets = {
    "192": (X_train_sc, X_val_sc, X_test_sc),
    "100": (X_train_pca100, X_val_pca100, X_test_pca100),
    "50": (X_train_pca50, X_val_pca50, X_test_pca50),
}

# Display name -> list of template estimators. Templates are never fitted;
# a fresh copy is built from their parameters for each feature set.
model_defs = {
    "Logistic Regression": [
        LogisticRegression(
            max_iter=2000,
            solver="lbfgs",
            n_jobs=-1,
            random_state=RANDOM_STATE,
            verbose=0,
        ),
    ],
    "SVM (Linear)": [
        SVC(kernel="linear", C=1.0, random_state=RANDOM_STATE),
    ],
    "SVM (RBF)": [
        SVC(kernel="rbf", C=1.0, gamma="scale", random_state=RANDOM_STATE),
    ],
    "k-NN": [
        KNeighborsClassifier(n_neighbors=5, metric="minkowski", n_jobs=-1),
    ],
}

results = {}

# Shared keyword arguments for the macro-averaged metrics.
macro_kwargs = {"average": "macro", "zero_division": 0}
# Result key -> file stem: spaces become underscores, parentheses are dropped.
stem_table = str.maketrans({" ": "_", "(": None, ")": None})

for model_name, model_list in model_defs.items():
    print(f"\n --- {model_name} ---")
    for model in model_list:
        for dim_name, (X_tr, X_va, X_te) in feature_sets.items():
            key = f"{model_name}_{dim_name}"
            print(f" Training {key} ...", end=" ", flush=True)

            # Build an unfitted copy from the template's parameters and fit it.
            t_train = time.time()
            model_clone = type(model)(**model.get_params())
            model_clone.fit(X_tr, y_train_enc)
            train_time = time.time() - t_train

            # Predict on the test split (the validation split is not used here).
            t_pred = time.time()
            y_pred = model_clone.predict(X_te)
            pred_time = time.time() - t_pred

            acc = accuracy_score(y_test_enc, y_pred)
            prec = precision_score(y_test_enc, y_pred, **macro_kwargs)
            rec = recall_score(y_test_enc, y_pred, **macro_kwargs)
            f1 = f1_score(y_test_enc, y_pred, **macro_kwargs)
            cm = confusion_matrix(y_test_enc, y_pred)

            # The confusion matrix is stored as nested lists so it can be
            # written out as JSON later.
            results[key] = {
                "accuracy": acc,
                "precision_macro": prec,
                "recall_macro": rec,
                "f1_macro": f1,
                "train_time_s": train_time,
                "pred_time_s": pred_time,
                "confusion_matrix": cm.tolist(),
            }

            # Persist the fitted estimator.
            model_path = MODELS_DIR / f"{key.translate(stem_table)}.joblib"
            dump(model_clone, model_path)

            print(
                f"acc={acc:.4f} prec={prec:.4f} rec={rec:.4f} f1={f1:.4f} "
                f"train={train_time:.1f}s pred={pred_time:.1f}s"
            )
# ============================================================
# 7. Save Results
# ============================================================
print("\n[7/8] Saving results...")

# 7a. Accuracy comparison table: one row per model, one column per feature set.
acc_rows = []
for table_model in model_defs:
    acc_by_dim = {
        dim: results.get(f"{table_model}_{dim}", {}).get("accuracy", None)
        for dim in ("192", "100", "50")
    }
    acc_rows.append(
        {
            "Model": table_model,
            "Original (192)": acc_by_dim["192"],
            "PCA (100)": acc_by_dim["100"],
            "PCA (50)": acc_by_dim["50"],
        }
    )
acc_table = pd.DataFrame(acc_rows)
acc_table_path = OUTPUT_DIR / "accuracy_comparison_table.csv"
acc_table.to_csv(acc_table_path, index=False)
print("\n Accuracy Comparison Table:")
print(acc_table.to_string(index=False))
print(f" Saved: {acc_table_path}")

# 7b. Every recorded metric (including confusion matrices) as JSON.
results_path = OUTPUT_DIR / "full_results.json"
with open(results_path, "w") as f:
    json.dump(results, f, indent=2)
print(f" Saved: {results_path}")

# 7c. Explained-variance totals for the three PCA projections.
pca_var_df = pd.DataFrame(
    {
        "PCA Dimension": [100, 50, 2],
        "Explained Variance Ratio": [var_100, var_50, var_2],
    }
)
pca_var_path = OUTPUT_DIR / "pca_explained_variance.csv"
pca_var_df.to_csv(pca_var_path, index=False)
print(f" Saved: {pca_var_path}")
# ============================================================
# 8. Visualizations
# ============================================================
print("\n[8/8] Generating visualizations...")

# 8a. Accuracy bar chart: one group of bars per model, one bar per feature set.
fig, ax = plt.subplots(figsize=(12, 7))
# x, width and colors are reused by the F1 chart further down.
x = np.arange(len(model_defs))
width = 0.25
colors = ["#2196F3", "#4CAF50", "#FF9800"]

# Legend text per feature set. "192" is the un-reduced input, so it is labelled
# "Original" to match the accuracy table rather than being called a PCA result.
acc_dim_labels = {"192": "Original (192)", "100": "PCA (100)", "50": "PCA (50)"}

plotted_accs = []
for i, dim in enumerate(["192", "100", "50"]):
    accs = [results.get(f"{m}_{dim}", {}).get("accuracy", 0) for m in model_defs]
    plotted_accs.extend(accs)
    ax.bar(x + i * width, accs, width, label=acc_dim_labels[dim], color=colors[i])

ax.set_xlabel("Model", fontsize=13)
ax.set_ylabel("Accuracy", fontsize=13)
ax.set_title("Classification Accuracy by Model and PCA Dimensionality", fontsize=15)
# Centre each tick under the middle bar of its three-bar group.
ax.set_xticks(x + width)
ax.set_xticklabels(list(model_defs.keys()), rotation=15, ha="right")
# Keep the zoomed 0.90-1.0 range while every bar is above 0.90; otherwise lower
# the floor so a weaker model's bar is not clipped out of the plot.
lowest_acc = min(plotted_accs, default=1.0)
y_floor = 0.90 if lowest_acc > 0.90 else max(0.0, lowest_acc - 0.05)
ax.set_ylim(y_floor, 1.0)
ax.legend()
ax.grid(axis="y", linestyle="--", alpha=0.4)
plt.tight_layout()
acc_bar_path = PLOTS_DIR / "accuracy_comparison_bar.png"
fig.savefig(acc_bar_path, dpi=150)
plt.close(fig)
print(f" Saved: {acc_bar_path}")
# 8b. Confusion matrix heatmap for one selected model.
# NOTE(review): the key below is fixed, not derived from `results`; confirm it
# really is the best-scoring configuration before describing it as such.
best_key = "Logistic Regression_100"
best_cm = np.array(results[best_key]["confusion_matrix"])
n_classes = best_cm.shape[0]

if n_classes > 50:
    # Too many classes to draw in full: row-normalize, then plot an evenly
    # spaced subset of classes on both axes.
    fig, ax = plt.subplots(figsize=(14, 12))
    row_totals = best_cm.sum(axis=1, keepdims=True)
    # The small epsilon avoids division by zero for classes with no test rows.
    cm_norm = best_cm.astype(float) / (row_totals + 1e-10)
    sample_size = min(50, n_classes)
    indices = np.linspace(0, n_classes - 1, sample_size, dtype=int)
    cm_sample = cm_norm[np.ix_(indices, indices)]
    sns.heatmap(cm_sample, ax=ax, cmap="Blues", cbar_kws={"label": "Proportion"})
    ax.set_title(f"Confusion Matrix (Normalized) - {best_key} (sample {sample_size}x{sample_size})", fontsize=14)
else:
    # Small enough to show the raw counts for every class.
    fig, ax = plt.subplots(figsize=(10, 8))
    sns.heatmap(best_cm, ax=ax, cmap="Blues", fmt="d")
    ax.set_title(f"Confusion Matrix - {best_key}", fontsize=14)

# Axis labels are the same for both variants.
ax.set_xlabel("Predicted Speaker", fontsize=12)
ax.set_ylabel("True Speaker", fontsize=12)

plt.tight_layout()
cm_stem = best_key.replace(" ", "_").replace("(", "").replace(")", "")
cm_path = PLOTS_DIR / f"confusion_matrix_{cm_stem}.png"
fig.savefig(cm_path, dpi=150)
plt.close(fig)
print(f" Saved: {cm_path}")
# 8c. Macro-F1 bar chart: same grouped layout as the accuracy chart.
# Relies on x, width and colors defined for the accuracy chart above.
fig, ax = plt.subplots(figsize=(12, 7))

# Legend text per feature set. "192" is the un-reduced input, so it is labelled
# "Original" to match the accuracy table rather than being called a PCA result.
f1_dim_labels = {"192": "Original (192)", "100": "PCA (100)", "50": "PCA (50)"}

for i, dim in enumerate(["192", "100", "50"]):
    f1s = [results.get(f"{m}_{dim}", {}).get("f1_macro", 0) for m in model_defs]
    ax.bar(x + i * width, f1s, width, label=f1_dim_labels[dim], color=colors[i])

ax.set_xlabel("Model", fontsize=13)
ax.set_ylabel("Macro F1 Score", fontsize=13)
ax.set_title("Macro F1 Score by Model and PCA Dimensionality", fontsize=15)
# Centre each tick under the middle bar of its three-bar group.
ax.set_xticks(x + width)
ax.set_xticklabels(list(model_defs.keys()), rotation=15, ha="right")
ax.legend()
ax.grid(axis="y", linestyle="--", alpha=0.4)
plt.tight_layout()
f1_bar_path = PLOTS_DIR / "f1_comparison_bar.png"
fig.savefig(f1_bar_path, dpi=150)
plt.close(fig)
print(f" Saved: {f1_bar_path}")
# ============================================================
# Summary
# ============================================================
# Final console report: output locations, model count, and the five result
# entries with the highest recorded accuracy.
closing_rule = "=" * 60
print("\n" + closing_rule)
print("PIPELINE COMPLETE")
print(closing_rule)
print(f"\nResults directory: {OUTPUT_DIR.resolve()}")
print(f" Models: {MODELS_DIR.resolve()}")
print(f" Plots: {PLOTS_DIR.resolve()}")
print(f"\nTotal models saved: {len(results)}")
print("\nTop 5 results by accuracy:")

# Highest accuracy first; ties keep their insertion order.
ranked = sorted(results.items(), key=lambda entry: entry[1]["accuracy"], reverse=True)
for key, val in ranked[:5]:
    print(f" {key:40s} acc={val['accuracy']:.4f} f1={val['f1_macro']:.4f}")

print("\nDone!")