Satyawan1 commited on
Commit
ede2ce3
Β·
verified Β·
1 Parent(s): 425c0e6

Upload train_eeg_deep_analysis.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. train_eeg_deep_analysis.py +691 -0
train_eeg_deep_analysis.py ADDED
@@ -0,0 +1,691 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ EEG Deep Analysis: Spectral Connectivity + Graph Metrics + Temporal CNN Features
4
+ ================================================================================
5
+ Advanced 3-way (AD vs FTD vs Control) and binary (AD vs non-AD) classification
6
+ using graph-theoretic features from coherence networks, Hjorth parameters,
7
+ envelope statistics, and ensemble learning.
8
+
9
+ Dataset: OpenNeuro ds004504 (88 subjects: 36 AD, 23 FTD, 29 Control)
10
+ Author: Satyawan Singh β€” Infonova Solutions
11
+ """
12
+
13
+ import os
14
+ import json
15
+ import time
16
+ import pickle
17
+ import warnings
18
+ import numpy as np
19
+ import pandas as pd
20
+
21
+ warnings.filterwarnings('ignore')
22
+
23
+ # NumPy 2.0 compat: np.trapz -> np.trapezoid
24
+ if not hasattr(np, 'trapz'):
25
+ np.trapz = np.trapezoid
26
+
27
+ import mne
28
+ from scipy import signal
29
+ from scipy.stats import kurtosis, skew
30
+ from scipy.ndimage import uniform_filter1d
31
+
32
+ from sklearn.model_selection import StratifiedKFold, cross_val_predict
33
+ from sklearn.preprocessing import StandardScaler, LabelEncoder
34
+ from sklearn.ensemble import (
35
+ GradientBoostingClassifier,
36
+ RandomForestClassifier,
37
+ VotingClassifier,
38
+ )
39
+ from sklearn.svm import SVC
40
+ from sklearn.metrics import (
41
+ classification_report,
42
+ confusion_matrix,
43
+ roc_auc_score,
44
+ accuracy_score,
45
+ )
46
+ from sklearn.feature_selection import SelectKBest, f_classif
47
+ from sklearn.pipeline import Pipeline
48
+
49
+ # ── Paths ──
50
+ BASE = '/Users/satyawansingh/Documents/alzheimer-research-complete/data/openneuro_ad_eeg'
51
+ OUTPUT_DIR = '/Users/satyawansingh/Documents/alzheimer-research-complete/models/eeg_deep_analysis'
52
+ os.makedirs(OUTPUT_DIR, exist_ok=True)
53
+
54
+ CHANNELS = [
55
+ 'Fp1', 'Fp2', 'F3', 'F4', 'C3', 'C4', 'P3', 'P4',
56
+ 'O1', 'O2', 'F7', 'F8', 'T3', 'T4', 'T5', 'T6', 'Fz', 'Cz', 'Pz',
57
+ ]
58
+ N_CH = len(CHANNELS)
59
+
60
+ BANDS = {
61
+ 'theta': (4, 8),
62
+ 'alpha': (8, 13),
63
+ 'beta': (13, 30),
64
+ 'gamma': (30, 45),
65
+ }
66
+
67
+ # ══════════════════════════════════════════════════════════════
68
+ # FEATURE EXTRACTION FUNCTIONS
69
+ # ══════════════════════════════════════════════════════════════
70
+
71
+ def compute_coherence_matrix(data, sfreq, band):
72
+ """Compute pairwise coherence matrix for a given frequency band."""
73
+ fmin, fmax = band
74
+ n_ch = min(data.shape[0], N_CH)
75
+ coh_matrix = np.zeros((n_ch, n_ch))
76
+
77
+ for i in range(n_ch):
78
+ for j in range(i, n_ch):
79
+ if i == j:
80
+ coh_matrix[i, j] = 1.0
81
+ continue
82
+ freqs, coh = signal.coherence(
83
+ data[i], data[j], fs=sfreq,
84
+ nperseg=min(1024, len(data[i])),
85
+ )
86
+ band_mask = (freqs >= fmin) & (freqs <= fmax)
87
+ if band_mask.sum() > 0:
88
+ val = np.mean(coh[band_mask])
89
+ else:
90
+ val = 0.0
91
+ coh_matrix[i, j] = val
92
+ coh_matrix[j, i] = val
93
+
94
+ return coh_matrix
95
+
96
+
97
+ def graph_metrics_from_matrix(adj, band_name):
98
+ """
99
+ Extract graph-theoretic metrics from a coherence adjacency matrix.
100
+ Uses only numpy/scipy β€” no networkx dependency.
101
+ """
102
+ features = {}
103
+ n = adj.shape[0]
104
+
105
+ # Threshold to create binary graph (top 30% connections)
106
+ upper = adj[np.triu_indices(n, k=1)]
107
+ if len(upper) == 0 or np.std(upper) < 1e-12:
108
+ # Degenerate matrix β€” return zeros
109
+ for key in [
110
+ 'mean_coh', 'std_coh', 'global_efficiency', 'clustering_coeff',
111
+ 'char_path_length', 'small_worldness', 'modularity_approx',
112
+ 'degree_entropy', 'hub_score_max', 'hub_score_std',
113
+ 'assortativity_approx',
114
+ ]:
115
+ features[f'{band_name}_{key}'] = 0.0
116
+ return features
117
+
118
+ threshold = np.percentile(upper, 70)
119
+ binary = (adj >= threshold).astype(float)
120
+ np.fill_diagonal(binary, 0)
121
+
122
+ # --- Weighted metrics ---
123
+ features[f'{band_name}_mean_coh'] = np.mean(upper)
124
+ features[f'{band_name}_std_coh'] = np.std(upper)
125
+
126
+ # --- Degree & hub scores (weighted) ---
127
+ strength = adj.sum(axis=1) - 1 # subtract self-connection
128
+ degree = binary.sum(axis=1)
129
+ max_degree = degree.max() if degree.max() > 0 else 1
130
+ hub_scores = degree / max_degree
131
+ features[f'{band_name}_hub_score_max'] = hub_scores.max()
132
+ features[f'{band_name}_hub_score_std'] = hub_scores.std()
133
+
134
+ # Degree entropy
135
+ degree_norm = degree / (degree.sum() + 1e-10)
136
+ degree_norm = degree_norm[degree_norm > 0]
137
+ features[f'{band_name}_degree_entropy'] = -np.sum(degree_norm * np.log2(degree_norm + 1e-15))
138
+
139
+ # --- Clustering coefficient (binary) ---
140
+ # C_i = 2 * triangles_i / (k_i * (k_i - 1))
141
+ triangles = np.diag(binary @ binary @ binary) / 2
142
+ k = degree
143
+ denom = k * (k - 1)
144
+ denom[denom == 0] = 1
145
+ cc = 2 * triangles / denom
146
+ features[f'{band_name}_clustering_coeff'] = np.mean(cc)
147
+
148
+ # --- Characteristic path length via Floyd-Warshall on binary ---
149
+ # Distance matrix: 1/weight for connected, inf for disconnected
150
+ dist = np.full((n, n), np.inf)
151
+ np.fill_diagonal(dist, 0)
152
+ connected = binary > 0
153
+ # Use inverse coherence as distance for weighted path
154
+ dist[connected] = 1.0 / (adj[connected] + 1e-10)
155
+
156
+ # Floyd-Warshall
157
+ for k_node in range(n):
158
+ new_dist = dist[:, k_node, None] + dist[None, k_node, :]
159
+ dist = np.minimum(dist, new_dist)
160
+
161
+ finite_dists = dist[np.triu_indices(n, k=1)]
162
+ finite_dists = finite_dists[np.isfinite(finite_dists)]
163
+ if len(finite_dists) > 0:
164
+ cpl = np.mean(finite_dists)
165
+ else:
166
+ cpl = np.inf
167
+
168
+ features[f'{band_name}_char_path_length'] = cpl if np.isfinite(cpl) else 100.0
169
+
170
+ # --- Global efficiency: mean of 1/d_ij ---
171
+ inv_dist = 1.0 / (dist + 1e-10)
172
+ np.fill_diagonal(inv_dist, 0)
173
+ features[f'{band_name}_global_efficiency'] = inv_dist.sum() / (n * (n - 1))
174
+
175
+ # --- Small-worldness approximation ---
176
+ # sigma = (C / C_rand) / (L / L_rand)
177
+ # For Erdos-Renyi with same density: C_rand ~ p, L_rand ~ ln(n) / ln(k_mean)
178
+ p = binary.sum() / (n * (n - 1))
179
+ c_rand = max(p, 1e-10)
180
+ k_mean = degree.mean()
181
+ l_rand = np.log(n) / (np.log(k_mean + 1) + 1e-10) if k_mean > 1 else 10.0
182
+
183
+ C_real = features[f'{band_name}_clustering_coeff']
184
+ L_real = features[f'{band_name}_char_path_length']
185
+
186
+ sigma = (C_real / (c_rand + 1e-10)) / (L_real / (l_rand + 1e-10) + 1e-10)
187
+ features[f'{band_name}_small_worldness'] = sigma
188
+
189
+ # --- Modularity approximation (spectral bisection) ---
190
+ # Q = 1/(2m) * sum( A_ij - k_i*k_j/(2m) ) * delta(c_i, c_j)
191
+ m = binary.sum() / 2
192
+ if m > 0:
193
+ B = binary - np.outer(degree, degree) / (2 * m + 1e-10)
194
+ eigvals, eigvecs = np.linalg.eigh(B)
195
+ # Partition based on sign of leading eigenvector
196
+ partition = (eigvecs[:, -1] > 0).astype(int)
197
+ same_comm = np.outer(partition, partition) + np.outer(1 - partition, 1 - partition)
198
+ Q = np.sum(B * same_comm) / (4 * m + 1e-10)
199
+ features[f'{band_name}_modularity_approx'] = Q
200
+ else:
201
+ features[f'{band_name}_modularity_approx'] = 0.0
202
+
203
+ # --- Assortativity approximation ---
204
+ # Correlation of degrees at each end of edges
205
+ edges_i, edges_j = np.where(np.triu(binary, k=1) > 0)
206
+ if len(edges_i) > 2:
207
+ d_i = degree[edges_i]
208
+ d_j = degree[edges_j]
209
+ if np.std(d_i) > 0 and np.std(d_j) > 0:
210
+ features[f'{band_name}_assortativity_approx'] = np.corrcoef(d_i, d_j)[0, 1]
211
+ else:
212
+ features[f'{band_name}_assortativity_approx'] = 0.0
213
+ else:
214
+ features[f'{band_name}_assortativity_approx'] = 0.0
215
+
216
+ return features
217
+
218
+
219
+ def compute_temporal_cnn_features(data, sfreq):
220
+ """
221
+ Extract temporal / signal-morphology features per channel:
222
+ - Hjorth mobility & complexity
223
+ - Signal envelope statistics
224
+ - Zero-crossing rate
225
+ - Line length
226
+ - Higuchi fractal dimension approximation
227
+ """
228
+ features = {}
229
+ n_ch = min(data.shape[0], N_CH)
230
+
231
+ for ch_idx in range(n_ch):
232
+ ch = CHANNELS[ch_idx]
233
+ x = data[ch_idx]
234
+
235
+ # --- Basic stats ---
236
+ features[f'{ch}_mean'] = np.mean(x)
237
+ features[f'{ch}_std'] = np.std(x)
238
+ features[f'{ch}_kurtosis'] = kurtosis(x)
239
+ features[f'{ch}_skewness'] = skew(x)
240
+ features[f'{ch}_rms'] = np.sqrt(np.mean(x ** 2))
241
+
242
+ # --- Hjorth parameters ---
243
+ diff1 = np.diff(x)
244
+ diff2 = np.diff(diff1)
245
+ activity = np.var(x)
246
+ mobility = np.sqrt(np.var(diff1) / (activity + 1e-10))
247
+ complexity = np.sqrt(np.var(diff2) / (np.var(diff1) + 1e-10)) / (mobility + 1e-10)
248
+ features[f'{ch}_hjorth_activity'] = activity
249
+ features[f'{ch}_hjorth_mobility'] = mobility
250
+ features[f'{ch}_hjorth_complexity'] = complexity
251
+
252
+ # --- Zero-crossing rate ---
253
+ zcr = np.sum(np.diff(np.sign(x)) != 0) / len(x)
254
+ features[f'{ch}_zcr'] = zcr
255
+
256
+ # --- Signal envelope (analytic signal) ---
257
+ analytic = signal.hilbert(x)
258
+ envelope = np.abs(analytic)
259
+ features[f'{ch}_env_mean'] = np.mean(envelope)
260
+ features[f'{ch}_env_std'] = np.std(envelope)
261
+ features[f'{ch}_env_skew'] = skew(envelope)
262
+ features[f'{ch}_env_kurtosis'] = kurtosis(envelope)
263
+
264
+ # --- Line length (sum of absolute differences) ---
265
+ features[f'{ch}_line_length'] = np.mean(np.abs(diff1))
266
+
267
+ # --- Higuchi fractal dimension (fast approximation, k_max=10) ---
268
+ k_max = 10
269
+ N_pts = len(x)
270
+ lk = []
271
+ for k in range(1, k_max + 1):
272
+ lengths = []
273
+ for m in range(1, k + 1):
274
+ idx = np.arange(m - 1, N_pts, k)
275
+ if len(idx) < 2:
276
+ continue
277
+ seg = x[idx]
278
+ L_m = np.sum(np.abs(np.diff(seg))) * (N_pts - 1) / (k * len(seg) * k)
279
+ lengths.append(L_m)
280
+ if lengths:
281
+ lk.append(np.mean(lengths))
282
+ if len(lk) > 2:
283
+ ks = np.arange(1, len(lk) + 1)
284
+ log_k = np.log(ks)
285
+ log_lk = np.log(np.array(lk) + 1e-15)
286
+ # Linear fit slope = fractal dimension
287
+ slope, _ = np.polyfit(log_k, log_lk, 1)
288
+ features[f'{ch}_hfd'] = -slope
289
+ else:
290
+ features[f'{ch}_hfd'] = 0.0
291
+
292
+ return features
293
+
294
+
295
+ def compute_band_power_features(data, sfreq):
296
+ """Compute per-channel PSD band powers and key spectral ratios."""
297
+ features = {}
298
+ n_ch = min(data.shape[0], N_CH)
299
+
300
+ for ch_idx in range(n_ch):
301
+ ch = CHANNELS[ch_idx]
302
+ x = data[ch_idx]
303
+ freqs, psd = signal.welch(x, fs=sfreq, nperseg=min(2048, len(x)))
304
+ total = np.trapz(psd, freqs) + 1e-10
305
+
306
+ for bname, (fmin, fmax) in BANDS.items():
307
+ mask = (freqs >= fmin) & (freqs <= fmax)
308
+ bp = np.trapz(psd[mask], freqs[mask])
309
+ features[f'{ch}_{bname}_rel'] = bp / total
310
+
311
+ # Delta band too
312
+ delta_mask = (freqs >= 0.5) & (freqs <= 4)
313
+ delta_power = np.trapz(psd[delta_mask], freqs[delta_mask])
314
+ features[f'{ch}_delta_rel'] = delta_power / total
315
+
316
+ # Spectral ratios
317
+ alpha_mask = (freqs >= 8) & (freqs <= 13)
318
+ theta_mask = (freqs >= 4) & (freqs <= 8)
319
+ alpha_p = np.trapz(psd[alpha_mask], freqs[alpha_mask])
320
+ theta_p = np.trapz(psd[theta_mask], freqs[theta_mask])
321
+ features[f'{ch}_theta_alpha_ratio'] = theta_p / (alpha_p + 1e-10)
322
+ features[f'{ch}_delta_alpha_ratio'] = delta_power / (alpha_p + 1e-10)
323
+
324
+ # Peak alpha frequency
325
+ alpha_freqs = freqs[alpha_mask]
326
+ alpha_psd = psd[alpha_mask]
327
+ if len(alpha_psd) > 0:
328
+ features[f'{ch}_peak_alpha_freq'] = alpha_freqs[np.argmax(alpha_psd)]
329
+ else:
330
+ features[f'{ch}_peak_alpha_freq'] = 0
331
+
332
+ # Spectral entropy
333
+ psd_norm = psd / (psd.sum() + 1e-10)
334
+ psd_pos = psd_norm[psd_norm > 0]
335
+ features[f'{ch}_spectral_entropy'] = -np.sum(psd_pos * np.log2(psd_pos))
336
+
337
+ return features
338
+
339
+
340
+ # ══════════════════════════════════════════════════════════════
341
+ # STEP 1: Load participants
342
+ # ══════════════════════════════════════════════════════════════
343
+ print("=" * 70)
344
+ print(" EEG DEEP ANALYSIS: Connectivity Graphs + Temporal + Ensemble")
345
+ print("=" * 70)
346
+
347
+ participants = pd.read_csv(os.path.join(BASE, 'participants.tsv'), sep='\t')
348
+ print(f"\nParticipants: {len(participants)}")
349
+ print(f"Groups: {dict(participants['Group'].value_counts())}")
350
+
351
+ label_map = {'A': 0, 'C': 1, 'F': 2}
352
+ label_names = {0: 'AD', 1: 'Control', 2: 'FTD'}
353
+
354
+ # ══════════════════════════════════════════════════════════════
355
+ # STEP 2: Feature extraction
356
+ # ══════════════════════════════════════════════════════════════
357
+ print(f"\n{'=' * 70}")
358
+ print(" STEP 2: Extracting features (coherence graphs + temporal + PSD)")
359
+ print("=" * 70)
360
+
361
+ all_features = []
362
+ all_labels = []
363
+ all_subjects = []
364
+ failed = []
365
+
366
+ t0 = time.time()
367
+
368
+ for idx, row in participants.iterrows():
369
+ sub_id = row['participant_id']
370
+ group = row['Group']
371
+ label = label_map[group]
372
+
373
+ eeg_file = os.path.join(BASE, sub_id, 'eeg', f'{sub_id}_task-eyesclosed_eeg.set')
374
+ if not os.path.exists(eeg_file):
375
+ failed.append(sub_id)
376
+ continue
377
+
378
+ try:
379
+ raw = mne.io.read_raw_eeglab(eeg_file, preload=True, verbose=False)
380
+ raw.filter(0.5, 45, verbose=False)
381
+ data = raw.get_data()
382
+ sfreq = raw.info['sfreq']
383
+
384
+ feats = {}
385
+
386
+ # 1) Spectral connectivity graph metrics per band
387
+ for band_name, band_range in BANDS.items():
388
+ coh_mat = compute_coherence_matrix(data, sfreq, band_range)
389
+ graph_feats = graph_metrics_from_matrix(coh_mat, band_name)
390
+ feats.update(graph_feats)
391
+
392
+ # Also store upper-triangle coherence stats per band
393
+ upper = coh_mat[np.triu_indices(N_CH, k=1)]
394
+ feats[f'{band_name}_coh_median'] = np.median(upper)
395
+ feats[f'{band_name}_coh_q25'] = np.percentile(upper, 25)
396
+ feats[f'{band_name}_coh_q75'] = np.percentile(upper, 75)
397
+
398
+ # 2) Temporal / signal-morphology features
399
+ feats.update(compute_temporal_cnn_features(data, sfreq))
400
+
401
+ # 3) Band power features
402
+ feats.update(compute_band_power_features(data, sfreq))
403
+
404
+ # 4) Demographics
405
+ feats['age'] = row['Age']
406
+ feats['gender'] = 1 if row['Gender'] == 'M' else 0
407
+ feats['mmse'] = row['MMSE']
408
+
409
+ all_features.append(feats)
410
+ all_labels.append(label)
411
+ all_subjects.append(sub_id)
412
+
413
+ elapsed = time.time() - t0
414
+ print(f" [{idx+1:2d}/{len(participants)}] {sub_id} [{label_names[label]:>7s}] "
415
+ f"β€” {len(feats)} features ({elapsed:.0f}s)")
416
+
417
+ except Exception as e:
418
+ print(f" [{idx+1:2d}/{len(participants)}] {sub_id} FAILED: {e}")
419
+ failed.append(sub_id)
420
+
421
+ print(f"\nExtracted: {len(all_features)} subjects | Failed: {len(failed)}")
422
+ print(f"Total time: {time.time() - t0:.0f}s")
423
+
424
+ # Convert to matrix
425
+ X = pd.DataFrame(all_features).fillna(0)
426
+ y = np.array(all_labels)
427
+
428
+ # Replace inf with large finite value
429
+ X = X.replace([np.inf, -np.inf], 0)
430
+
431
+ print(f"Feature matrix: {X.shape}")
432
+ print(f"Labels: AD={sum(y==0)}, Control={sum(y==1)}, FTD={sum(y==2)}")
433
+
434
+ # Save raw features
435
+ X.to_csv(os.path.join(OUTPUT_DIR, 'deep_features.csv'), index=False)
436
+ np.save(os.path.join(OUTPUT_DIR, 'labels.npy'), y)
437
+
438
+ # ══════════════════════════════════════════════════════════════
439
+ # STEP 3: Feature selection & scaling
440
+ # ══════════════════════════════════════════════════════════════
441
+ print(f"\n{'=' * 70}")
442
+ print(" STEP 3: Feature selection")
443
+ print("=" * 70)
444
+
445
+ k_features = min(120, X.shape[1])
446
+ selector = SelectKBest(f_classif, k=k_features)
447
+ X_selected = selector.fit_transform(X, y)
448
+
449
+ selected_mask = selector.get_support()
450
+ selected_features = X.columns[selected_mask].tolist()
451
+ print(f"Selected {len(selected_features)} / {X.shape[1]} features")
452
+
453
+ # Top 25 features by F-score
454
+ scores = selector.scores_[selected_mask]
455
+ top_idx = np.argsort(scores)[::-1][:25]
456
+ print("\nTop 25 most discriminative features:")
457
+ for i, ix in enumerate(top_idx):
458
+ print(f" {i+1:2d}. {selected_features[ix]:45s} F={scores[ix]:.1f}")
459
+
460
+ scaler = StandardScaler()
461
+ X_scaled = scaler.fit_transform(X_selected)
462
+
463
+ # ══════════════════════════════════════════════════════════════
464
+ # STEP 4: 3-way classification (AD vs FTD vs Control)
465
+ # ══════════════════════════════════════════════════════════════
466
+ print(f"\n{'=' * 70}")
467
+ print(" STEP 4: 3-way classification (AD vs FTD vs Control)")
468
+ print("=" * 70)
469
+
470
+ cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
471
+
472
+ gb = GradientBoostingClassifier(
473
+ n_estimators=300, max_depth=4, learning_rate=0.05,
474
+ subsample=0.8, random_state=42,
475
+ )
476
+ rf = RandomForestClassifier(
477
+ n_estimators=400, max_depth=6, min_samples_leaf=2, random_state=42,
478
+ )
479
+ svm = SVC(
480
+ kernel='rbf', C=10, gamma='scale', probability=True, random_state=42,
481
+ )
482
+
483
+ models_3c = {
484
+ 'GradientBoosting': gb,
485
+ 'RandomForest': rf,
486
+ 'SVM_RBF': svm,
487
+ }
488
+
489
+ # Also build voting ensemble
490
+ voting = VotingClassifier(
491
+ estimators=[('gb', gb), ('rf', rf), ('svm', svm)],
492
+ voting='soft',
493
+ )
494
+ models_3c['VotingEnsemble'] = voting
495
+
496
+ results_3c = {}
497
+
498
+ for name, model in models_3c.items():
499
+ y_pred = cross_val_predict(model, X_scaled, y, cv=cv)
500
+ y_prob = cross_val_predict(model, X_scaled, y, cv=cv, method='predict_proba')
501
+
502
+ acc = accuracy_score(y, y_pred)
503
+ # One-vs-rest AUC
504
+ try:
505
+ auc = roc_auc_score(y, y_prob, multi_class='ovr', average='weighted')
506
+ except Exception:
507
+ auc = 0.0
508
+
509
+ results_3c[name] = {'accuracy': acc, 'auc': auc, 'y_pred': y_pred, 'y_prob': y_prob}
510
+
511
+ print(f"\n{'─' * 50}")
512
+ print(f" {name}: Accuracy = {acc:.1%} | AUC(OvR) = {auc:.3f}")
513
+ print(f"{'─' * 50}")
514
+ print(classification_report(y, y_pred, target_names=['AD', 'Control', 'FTD']))
515
+ print("Confusion matrix:")
516
+ cm = confusion_matrix(y, y_pred)
517
+ print(f" {'':>10s} pred_AD pred_Ctrl pred_FTD")
518
+ for i, lbl in enumerate(['AD', 'Control', 'FTD']):
519
+ print(f" {lbl:>10s} {cm[i,0]:7d} {cm[i,1]:9d} {cm[i,2]:8d}")
520
+
521
+ best_3c = max(results_3c, key=lambda k: results_3c[k]['accuracy'])
522
+ print(f"\n>>> Best 3-class model: {best_3c} "
523
+ f"(Acc={results_3c[best_3c]['accuracy']:.1%}, "
524
+ f"AUC={results_3c[best_3c]['auc']:.3f})")
525
+
526
+ # ══════════════════════════════════════════════════════════════
527
+ # STEP 5: Binary classification (AD vs non-AD)
528
+ # ══════════════════════════════════════════════════════════════
529
+ print(f"\n{'=' * 70}")
530
+ print(" STEP 5: Binary classification (AD vs non-AD)")
531
+ print("=" * 70)
532
+
533
+ y_binary = (y == 0).astype(int) # AD=1, non-AD=0
534
+
535
+ gb_b = GradientBoostingClassifier(
536
+ n_estimators=300, max_depth=4, learning_rate=0.05,
537
+ subsample=0.8, random_state=42,
538
+ )
539
+ rf_b = RandomForestClassifier(
540
+ n_estimators=400, max_depth=6, min_samples_leaf=2, random_state=42,
541
+ )
542
+ svm_b = SVC(
543
+ kernel='rbf', C=10, gamma='scale', probability=True, random_state=42,
544
+ )
545
+ voting_b = VotingClassifier(
546
+ estimators=[('gb', gb_b), ('rf', rf_b), ('svm', svm_b)],
547
+ voting='soft',
548
+ )
549
+
550
+ models_bin = {
551
+ 'GradientBoosting': gb_b,
552
+ 'RandomForest': rf_b,
553
+ 'SVM_RBF': svm_b,
554
+ 'VotingEnsemble': voting_b,
555
+ }
556
+
557
+ cv_bin = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
558
+ results_bin = {}
559
+
560
+ for name, model in models_bin.items():
561
+ y_pred = cross_val_predict(model, X_scaled, y_binary, cv=cv_bin)
562
+ y_prob = cross_val_predict(model, X_scaled, y_binary, cv=cv_bin, method='predict_proba')
563
+
564
+ acc = accuracy_score(y_binary, y_pred)
565
+ auc = roc_auc_score(y_binary, y_prob[:, 1])
566
+ sens = np.sum((y_pred == 1) & (y_binary == 1)) / (np.sum(y_binary == 1) + 1e-10)
567
+ spec = np.sum((y_pred == 0) & (y_binary == 0)) / (np.sum(y_binary == 0) + 1e-10)
568
+
569
+ results_bin[name] = {'accuracy': acc, 'auc': auc, 'sensitivity': sens, 'specificity': spec}
570
+
571
+ print(f"\n{'─' * 50}")
572
+ print(f" {name}: Acc={acc:.1%} AUC={auc:.3f} Sens={sens:.1%} Spec={spec:.1%}")
573
+ print(f"{'─' * 50}")
574
+ print(classification_report(y_binary, y_pred, target_names=['non-AD', 'AD']))
575
+
576
+ best_bin = max(results_bin, key=lambda k: results_bin[k]['auc'])
577
+ print(f"\n>>> Best binary model: {best_bin} "
578
+ f"(AUC={results_bin[best_bin]['auc']:.3f}, "
579
+ f"Acc={results_bin[best_bin]['accuracy']:.1%})")
580
+
581
+ # ══════════════════════════════════════════════════════════════
582
+ # STEP 6: Train final models on full data & save
583
+ # ══════════════════════════════════════════════════════════════
584
+ print(f"\n{'=' * 70}")
585
+ print(" STEP 6: Training final models & saving artifacts")
586
+ print("=" * 70)
587
+
588
+ # Final 3-class
589
+ final_3c = VotingClassifier(
590
+ estimators=[
591
+ ('gb', GradientBoostingClassifier(
592
+ n_estimators=300, max_depth=4, learning_rate=0.05,
593
+ subsample=0.8, random_state=42)),
594
+ ('rf', RandomForestClassifier(
595
+ n_estimators=400, max_depth=6, min_samples_leaf=2, random_state=42)),
596
+ ('svm', SVC(kernel='rbf', C=10, gamma='scale', probability=True, random_state=42)),
597
+ ],
598
+ voting='soft',
599
+ )
600
+ final_3c.fit(X_scaled, y)
601
+
602
+ # Final binary
603
+ final_bin = VotingClassifier(
604
+ estimators=[
605
+ ('gb', GradientBoostingClassifier(
606
+ n_estimators=300, max_depth=4, learning_rate=0.05,
607
+ subsample=0.8, random_state=42)),
608
+ ('rf', RandomForestClassifier(
609
+ n_estimators=400, max_depth=6, min_samples_leaf=2, random_state=42)),
610
+ ('svm', SVC(kernel='rbf', C=10, gamma='scale', probability=True, random_state=42)),
611
+ ],
612
+ voting='soft',
613
+ )
614
+ final_bin.fit(X_scaled, y_binary)
615
+
616
+ # Feature importance from the GradientBoosting inside the ensemble
617
+ gb_inside = final_3c.named_estimators_['gb']
618
+ importances = gb_inside.feature_importances_
619
+ top_fi = np.argsort(importances)[::-1][:20]
620
+
621
+ print("\nTop 20 feature importances (GradientBoosting, 3-class):")
622
+ for i, ix in enumerate(top_fi):
623
+ print(f" {i+1:2d}. {selected_features[ix]:45s} imp={importances[ix]:.4f}")
624
+
625
+ # Save feature importance
626
+ fi_df = pd.DataFrame({
627
+ 'feature': selected_features,
628
+ 'importance': importances,
629
+ }).sort_values('importance', ascending=False)
630
+ fi_df.to_csv(os.path.join(OUTPUT_DIR, 'feature_importance.csv'), index=False)
631
+
632
+ # Save models
633
+ artifacts = {
634
+ 'model_3class': final_3c,
635
+ 'model_binary': final_bin,
636
+ 'scaler': scaler,
637
+ 'selector': selector,
638
+ 'feature_names': list(X.columns),
639
+ 'selected_features': selected_features,
640
+ 'label_names_3class': {0: 'AD', 1: 'Control', 2: 'FTD'},
641
+ 'label_names_binary': {0: 'non-AD', 1: 'AD'},
642
+ 'channels': CHANNELS,
643
+ 'bands': BANDS,
644
+ 'results_3class': {k: {kk: vv for kk, vv in v.items() if kk not in ('y_pred', 'y_prob')}
645
+ for k, v in results_3c.items()},
646
+ 'results_binary': results_bin,
647
+ }
648
+
649
+ model_path = os.path.join(OUTPUT_DIR, 'eeg_deep_analysis.pkl')
650
+ with open(model_path, 'wb') as f:
651
+ pickle.dump(artifacts, f)
652
+ print(f"\nModel saved: {model_path} ({os.path.getsize(model_path)/1e6:.1f} MB)")
653
+
654
+ # Save results summary as JSON
655
+ summary = {
656
+ 'dataset': 'OpenNeuro ds004504',
657
+ 'n_subjects': len(all_features),
658
+ 'n_failed': len(failed),
659
+ 'n_features_total': X.shape[1],
660
+ 'n_features_selected': len(selected_features),
661
+ 'classification_3way': {
662
+ name: {
663
+ 'accuracy': float(f"{v['accuracy']:.4f}"),
664
+ 'auc_ovr': float(f"{v['auc']:.4f}"),
665
+ }
666
+ for name, v in results_3c.items()
667
+ },
668
+ 'classification_binary_AD_vs_nonAD': {
669
+ name: {k: float(f"{vv:.4f}") for k, vv in v.items()}
670
+ for name, v in results_bin.items()
671
+ },
672
+ 'best_3class_model': best_3c,
673
+ 'best_binary_model': best_bin,
674
+ }
675
+
676
+ with open(os.path.join(OUTPUT_DIR, 'results_summary.json'), 'w') as f:
677
+ json.dump(summary, f, indent=2)
678
+
679
+ # ══════════════════════════════════════════════════════════════
680
+ # FINAL SUMMARY
681
+ # ══════════════════════════════════════════════════════════════
682
+ print(f"\n{'=' * 70}")
683
+ print(" EEG DEEP ANALYSIS β€” COMPLETE")
684
+ print("=" * 70)
685
+ print(f" Subjects: {len(all_features)} ({sum(y==0)} AD, {sum(y==1)} Ctrl, {sum(y==2)} FTD)")
686
+ print(f" Features: {X.shape[1]} total -> {len(selected_features)} selected")
687
+ print(f" 3-class best: {best_3c} Acc={results_3c[best_3c]['accuracy']:.1%} AUC={results_3c[best_3c]['auc']:.3f}")
688
+ print(f" Binary best: {best_bin} AUC={results_bin[best_bin]['auc']:.3f} Acc={results_bin[best_bin]['accuracy']:.1%}")
689
+ print(f" Output dir: {OUTPUT_DIR}")
690
+ print(f" Author: Satyawan Singh β€” Infonova Solutions")
691
+ print("=" * 70)