# alzheimer-research-complete / train_eeg_deep_analysis.py
# Uploaded by Satyawan1 with huggingface_hub (commit ede2ce3, verified)
#!/usr/bin/env python3
"""
EEG Deep Analysis: Spectral Connectivity + Graph Metrics + Temporal CNN Features
================================================================================
Advanced 3-way (AD vs FTD vs Control) and binary (AD vs non-AD) classification
using graph-theoretic features from coherence networks, Hjorth parameters,
envelope statistics, and ensemble learning.
Dataset: OpenNeuro ds004504 (88 subjects: 36 AD, 23 FTD, 29 Control)
Author: Satyawan Singh — Infonova Solutions
"""
import os
import json
import time
import pickle
import warnings
import numpy as np
import pandas as pd
# Silence library warnings (MNE / scikit-learn / SciPy) for the whole run so
# the progress output stays readable.
warnings.filterwarnings(action='ignore')
# NumPy 2.0 compat: np.trapz was renamed np.trapezoid. If the legacy name is
# missing, alias it so np.trapz calls elsewhere in this file keep working.
try:
    np.trapz
except AttributeError:
    np.trapz = np.trapezoid
import mne
from scipy import signal
from scipy.stats import kurtosis, skew
from scipy.ndimage import uniform_filter1d
from sklearn.model_selection import StratifiedKFold, cross_val_predict
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.ensemble import (
GradientBoostingClassifier,
RandomForestClassifier,
VotingClassifier,
)
from sklearn.svm import SVC
from sklearn.metrics import (
classification_report,
confusion_matrix,
roc_auc_score,
accuracy_score,
)
from sklearn.feature_selection import SelectKBest, f_classif
from sklearn.pipeline import Pipeline
# ── Paths ──
# Defaults point at the author's local checkout; set EEG_DATA_DIR and/or
# EEG_OUTPUT_DIR in the environment to run the script on another machine.
BASE = os.environ.get(
    'EEG_DATA_DIR',
    '/Users/satyawansingh/Documents/alzheimer-research-complete/data/openneuro_ad_eeg',
)
OUTPUT_DIR = os.environ.get(
    'EEG_OUTPUT_DIR',
    '/Users/satyawansingh/Documents/alzheimer-research-complete/models/eeg_deep_analysis',
)
os.makedirs(OUTPUT_DIR, exist_ok=True)
# 19-channel montage. Feature functions pair CHANNELS[i] with row i of the
# recording's data matrix, so this order is assumed to match the files.
CHANNELS = [
    'Fp1', 'Fp2', 'F3', 'F4', 'C3', 'C4', 'P3', 'P4',
    'O1', 'O2', 'F7', 'F8', 'T3', 'T4', 'T5', 'T6', 'Fz', 'Cz', 'Pz',
]
N_CH = len(CHANNELS)
# Frequency bands (Hz, both edges inclusive) used for coherence graphs and
# relative band power.
BANDS = {
    'theta': (4, 8),
    'alpha': (8, 13),
    'beta': (13, 30),
    'gamma': (30, 45),
}
# ══════════════════════════════════════════════════════════════
# FEATURE EXTRACTION FUNCTIONS
# ══════════════════════════════════════════════════════════════
def compute_coherence_matrix(data, sfreq, band, max_channels=None):
    """Compute the pairwise magnitude-squared coherence matrix for one band.

    Coherence is estimated with Welch's method (SciPy defaults: Hann window,
    50% overlap, ``nperseg = min(1024, n_samples)``). It is then averaged over
    the frequency bins inside ``band``; both edges are inclusive.

    Parameters
    ----------
    data : ndarray, shape (n_channels, n_samples)
        Multichannel signal. Only the first ``max_channels`` rows are used.
    sfreq : float
        Sampling frequency in Hz.
    band : tuple of float
        ``(fmin, fmax)`` in Hz.
    max_channels : int, optional
        Upper bound on the number of rows used. Defaults to ``N_CH``.

    Returns
    -------
    ndarray, shape (n, n)
        Symmetric matrix with ones on the diagonal. Off-diagonal entries are
        0.0 when no frequency bin falls inside ``band``.
    """
    fmin, fmax = band
    limit = N_CH if max_channels is None else max_channels
    n_ch = min(data.shape[0], limit)
    coh_matrix = np.eye(n_ch)
    if n_ch < 2:
        return coh_matrix

    nperseg = min(1024, data.shape[-1])
    # Auto-spectra of every channel, computed once rather than once per pair.
    freqs, pxx = signal.welch(data[:n_ch], fs=sfreq, nperseg=nperseg)
    band_mask = (freqs >= fmin) & (freqs <= fmax)
    if not band_mask.any():
        return coh_matrix

    for i in range(n_ch - 1):
        partners = data[i + 1:n_ch]
        # Cross-spectra of channel i against all later channels in one call.
        # SciPy broadcasts the (1, n_samples) row against the (m, n_samples) block.
        _, pxy = signal.csd(data[i][np.newaxis, :], partners, fs=sfreq, nperseg=nperseg)
        # Same formula as scipy.signal.coherence: |Pxy|^2 / Pxx / Pyy.
        coh = np.abs(pxy) ** 2 / pxx[i] / pxx[i + 1:n_ch]
        vals = coh[:, band_mask].mean(axis=1)
        coh_matrix[i, i + 1:] = vals
        coh_matrix[i + 1:, i] = vals
    return coh_matrix
def _shortest_path_lengths(adj, binary):
    """All-pairs shortest path lengths over the binary graph (Floyd-Warshall).

    Each retained edge (i, j) has length 1 / adj[i, j], so strongly coherent
    links are short. Unreachable pairs stay at ``np.inf``.
    """
    n = adj.shape[0]
    dist = np.full((n, n), np.inf)
    np.fill_diagonal(dist, 0)
    connected = binary > 0
    dist[connected] = 1.0 / (adj[connected] + 1e-10)
    for via in range(n):
        dist = np.minimum(dist, dist[:, via, None] + dist[None, via, :])
    return dist


def _spectral_modularity(binary, degree):
    """Newman modularity Q of the two-way split from the leading eigenvector.

    Q = 1/(2m) * sum_ij (A_ij - k_i k_j / (2m)) * delta(c_i, c_j). The two
    communities are given by the sign of the leading eigenvector of the
    modularity matrix B. Returns 0.0 for a graph without edges.
    """
    m = binary.sum() / 2
    if m <= 0:
        return 0.0
    B = binary - np.outer(degree, degree) / (2 * m)
    _, eigvecs = np.linalg.eigh(B)  # eigenvalues ascending: last column leads
    side = eigvecs[:, -1] > 0
    same_comm = np.equal.outer(side, side)
    return np.sum(B[same_comm]) / (2 * m)


def _degree_assortativity(binary, degree):
    """Degree assortativity: Pearson r of the degrees at the two ends of each edge.

    Every undirected edge is counted in both orientations, so the result does
    not depend on node (channel) ordering. Returns 0.0 when there are at most
    two edges or the edge-end degrees do not vary.
    """
    src, dst = np.nonzero(np.triu(binary, k=1))
    if len(src) <= 2:
        return 0.0
    end_a = np.concatenate([degree[src], degree[dst]])
    end_b = np.concatenate([degree[dst], degree[src]])
    if not np.std(end_a) > 0:
        return 0.0
    return np.corrcoef(end_a, end_b)[0, 1]


def graph_metrics_from_matrix(adj, band_name):
    """
    Extract graph-theoretic metrics from a coherence adjacency matrix.

    The matrix is binarised by keeping connections at or above the 70th
    percentile of its off-diagonal values (roughly the strongest 30%).
    Uses only numpy — no networkx dependency.

    Parameters
    ----------
    adj : ndarray, shape (n, n)
        Symmetric coherence matrix. The diagonal is ignored in the binary graph.
    band_name : str
        Prefix for every feature key, e.g. ``'alpha'``.

    Returns
    -------
    dict
        11 entries keyed ``f'{band_name}_{metric}'``. Every value is 0.0 when
        the matrix has fewer than two nodes or constant off-diagonal values.
    """
    metric_keys = (
        'mean_coh', 'std_coh', 'global_efficiency', 'clustering_coeff',
        'char_path_length', 'small_worldness', 'modularity_approx',
        'degree_entropy', 'hub_score_max', 'hub_score_std',
        'assortativity_approx',
    )
    n = adj.shape[0]
    upper = adj[np.triu_indices(n, k=1)]
    if len(upper) == 0 or np.std(upper) < 1e-12:
        # Degenerate matrix: no meaningful threshold exists, so report zeros.
        return {f'{band_name}_{key}': 0.0 for key in metric_keys}

    features = {}
    # Binary graph: keep connections >= 70th percentile (ties may keep more).
    threshold = np.percentile(upper, 70)
    binary = (adj >= threshold).astype(float)
    np.fill_diagonal(binary, 0)
    degree = binary.sum(axis=1)

    # --- Summary of the raw (weighted) coherences ---
    features[f'{band_name}_mean_coh'] = np.mean(upper)
    features[f'{band_name}_std_coh'] = np.std(upper)

    # --- Hub scores: binary degree relative to the best-connected node ---
    max_degree = degree.max() if degree.max() > 0 else 1
    hub_scores = degree / max_degree
    features[f'{band_name}_hub_score_max'] = hub_scores.max()
    features[f'{band_name}_hub_score_std'] = hub_scores.std()

    # --- Shannon entropy (bits) of the normalised degree distribution ---
    degree_norm = degree / (degree.sum() + 1e-10)
    degree_norm = degree_norm[degree_norm > 0]
    features[f'{band_name}_degree_entropy'] = -np.sum(degree_norm * np.log2(degree_norm + 1e-15))

    # --- Clustering coefficient (binary): C_i = 2 * T_i / (k_i * (k_i - 1)) ---
    # diag(A^3)_i counts every triangle through node i twice.
    triangles = np.diag(binary @ binary @ binary) / 2
    denom = degree * (degree - 1)
    denom[denom == 0] = 1
    clustering = np.mean(2 * triangles / denom)
    features[f'{band_name}_clustering_coeff'] = clustering

    # --- Characteristic path length over reachable pairs (1/coherence lengths) ---
    dist = _shortest_path_lengths(adj, binary)
    pair_dists = dist[np.triu_indices(n, k=1)]
    finite_dists = pair_dists[np.isfinite(pair_dists)]
    # 100.0 is the sentinel used when no pair is reachable.
    path_length = np.mean(finite_dists) if len(finite_dists) > 0 else 100.0
    features[f'{band_name}_char_path_length'] = path_length

    # --- Global efficiency: mean of 1/d_ij over ordered pairs i != j ---
    inv_dist = 1.0 / (dist + 1e-10)  # unreachable pairs: 1/inf -> 0
    np.fill_diagonal(inv_dist, 0)
    features[f'{band_name}_global_efficiency'] = inv_dist.sum() / (n * (n - 1))

    # --- Small-worldness approximation: sigma = (C / C_rand) / (L / L_rand) ---
    # Erdos-Renyi reference with the same density: C_rand ~ p,
    # L_rand ~ ln(n) / ln(k_mean + 1).
    # NOTE(review): L_real uses 1/coherence edge lengths while L_rand counts hops.
    p = binary.sum() / (n * (n - 1))
    c_rand = max(p, 1e-10)
    k_mean = degree.mean()
    l_rand = np.log(n) / (np.log(k_mean + 1) + 1e-10) if k_mean > 1 else 10.0
    features[f'{band_name}_small_worldness'] = (
        (clustering / (c_rand + 1e-10)) / (path_length / (l_rand + 1e-10) + 1e-10)
    )

    # --- Modularity (spectral bisection) and degree assortativity ---
    features[f'{band_name}_modularity_approx'] = _spectral_modularity(binary, degree)
    features[f'{band_name}_assortativity_approx'] = _degree_assortativity(binary, degree)
    return features
def _higuchi_fd(x, k_max=10):
    """Higuchi fractal dimension of a 1-D signal.

    For each delay k = 1..k_max, the normalised curve length L(k) is averaged
    over the k start offsets. The dimension is minus the slope of log L(k)
    versus log k. Returns 0.0 when fewer than three delays give a usable
    curve length (very short signals).
    """
    n_pts = len(x)
    delays = []
    curve_lengths = []
    for k in range(1, k_max + 1):
        lengths = []
        for start in range(k):
            seg = x[start::k]
            n_steps = len(seg) - 1  # number of increments in this sub-series
            if n_steps < 1:
                continue
            # Higuchi normalisation: (N - 1) / (n_steps * k), then divide by k.
            lengths.append(np.sum(np.abs(np.diff(seg))) * (n_pts - 1) / (n_steps * k * k))
        if lengths:
            delays.append(k)
            curve_lengths.append(np.mean(lengths))
    if len(curve_lengths) <= 2:
        return 0.0
    slope, _ = np.polyfit(np.log(delays), np.log(np.array(curve_lengths) + 1e-15), 1)
    return -slope


def compute_temporal_cnn_features(data, sfreq, channels=None):
    """
    Extract temporal / signal-morphology features per channel:
      - basic statistics (mean, std, kurtosis, skewness, RMS)
      - Hjorth activity, mobility & complexity
      - zero-crossing rate
      - analytic-signal envelope statistics
      - line length (mean absolute first difference)
      - Higuchi fractal dimension (k_max = 10)

    Parameters
    ----------
    data : ndarray, shape (n_channels, n_samples)
        One signal per row.
    sfreq : float
        Sampling frequency in Hz. Not used by these features; kept so all
        feature functions share the same signature.
    channels : sequence of str, optional
        Key prefixes, paired with ``data`` rows by position. Defaults to the
        module-level ``CHANNELS``. Rows or names without a partner are ignored.

    Returns
    -------
    dict
        ``{f'{channel}_{feature}': value}``, 15 features per channel.
    """
    if channels is None:
        channels = CHANNELS
    features = {}
    for ch, x in zip(channels, data):
        # --- Basic stats ---
        features[f'{ch}_mean'] = np.mean(x)
        features[f'{ch}_std'] = np.std(x)
        features[f'{ch}_kurtosis'] = kurtosis(x)
        features[f'{ch}_skewness'] = skew(x)
        features[f'{ch}_rms'] = np.sqrt(np.mean(x ** 2))

        # --- Hjorth parameters (epsilons guard against flat signals) ---
        diff1 = np.diff(x)
        diff2 = np.diff(diff1)
        activity = np.var(x)
        mobility = np.sqrt(np.var(diff1) / (activity + 1e-10))
        complexity = np.sqrt(np.var(diff2) / (np.var(diff1) + 1e-10)) / (mobility + 1e-10)
        features[f'{ch}_hjorth_activity'] = activity
        features[f'{ch}_hjorth_mobility'] = mobility
        features[f'{ch}_hjorth_complexity'] = complexity

        # --- Zero-crossing rate: sign changes per sample ---
        features[f'{ch}_zcr'] = np.sum(np.diff(np.sign(x)) != 0) / len(x)

        # --- Envelope = magnitude of the analytic (Hilbert) signal ---
        envelope = np.abs(signal.hilbert(x))
        features[f'{ch}_env_mean'] = np.mean(envelope)
        features[f'{ch}_env_std'] = np.std(envelope)
        features[f'{ch}_env_skew'] = skew(envelope)
        features[f'{ch}_env_kurtosis'] = kurtosis(envelope)

        # --- Line length (mean absolute first difference) ---
        features[f'{ch}_line_length'] = np.mean(np.abs(diff1))

        # --- Higuchi fractal dimension ---
        features[f'{ch}_hfd'] = _higuchi_fd(x, k_max=10)
    return features
def compute_band_power_features(data, sfreq, channels=None, bands=None):
    """Compute per-channel relative band powers and key spectral ratios.

    The PSD comes from Welch's method with ``nperseg = min(2048, n_samples)``.
    Band power is the trapezoid-rule integral of the PSD over the band (both
    edges inclusive). "Relative" means divided by the integral over the whole
    spectrum.

    Parameters
    ----------
    data : ndarray, shape (n_channels, n_samples)
        One signal per row.
    sfreq : float
        Sampling frequency in Hz.
    channels : sequence of str, optional
        Key prefixes, paired with ``data`` rows by position. Defaults to the
        module-level ``CHANNELS``.
    bands : dict of str -> (fmin, fmax), optional
        Bands reported as ``{ch}_{band}_rel``. Defaults to the module-level
        ``BANDS``. Delta, the ratios and peak alpha always use fixed edges.

    Returns
    -------
    dict
        Per channel: one relative power per band, plus ``delta_rel``,
        ``theta_alpha_ratio``, ``delta_alpha_ratio``, ``peak_alpha_freq``
        and ``spectral_entropy``.
    """
    if channels is None:
        channels = CHANNELS
    if bands is None:
        bands = BANDS
    # NumPy >= 2.0 names the trapezoid rule np.trapezoid; older releases only
    # provide np.trapz.
    integrate = getattr(np, 'trapezoid', None) or np.trapz

    def band_power(freqs, psd, fmin, fmax):
        # Absolute power between fmin and fmax, both inclusive.
        mask = (freqs >= fmin) & (freqs <= fmax)
        return integrate(psd[mask], freqs[mask])

    features = {}
    for ch, x in zip(channels, data):
        freqs, psd = signal.welch(x, fs=sfreq, nperseg=min(2048, len(x)))
        total = integrate(psd, freqs) + 1e-10
        for bname, (fmin, fmax) in bands.items():
            features[f'{ch}_{bname}_rel'] = band_power(freqs, psd, fmin, fmax) / total
        # Delta band (0.5-4 Hz) is always reported, independent of `bands`.
        delta_power = band_power(freqs, psd, 0.5, 4)
        features[f'{ch}_delta_rel'] = delta_power / total
        # Spectral ratios with fixed theta (4-8 Hz) and alpha (8-13 Hz) edges.
        theta_p = band_power(freqs, psd, 4, 8)
        alpha_p = band_power(freqs, psd, 8, 13)
        features[f'{ch}_theta_alpha_ratio'] = theta_p / (alpha_p + 1e-10)
        features[f'{ch}_delta_alpha_ratio'] = delta_power / (alpha_p + 1e-10)
        # Peak alpha frequency: frequency of the largest PSD bin in 8-13 Hz.
        alpha_mask = (freqs >= 8) & (freqs <= 13)
        alpha_psd = psd[alpha_mask]
        if alpha_psd.size > 0:
            features[f'{ch}_peak_alpha_freq'] = freqs[alpha_mask][np.argmax(alpha_psd)]
        else:
            features[f'{ch}_peak_alpha_freq'] = 0
        # Spectral entropy (bits) of the PSD normalised to sum to one.
        psd_norm = psd / (psd.sum() + 1e-10)
        psd_pos = psd_norm[psd_norm > 0]
        features[f'{ch}_spectral_entropy'] = -np.sum(psd_pos * np.log2(psd_pos))
    return features
# ══════════════════════════════════════════════════════════════
# STEP 1: Load participants
# ══════════════════════════════════════════════════════════════
_rule = "=" * 70
print(_rule)
print(" EEG DEEP ANALYSIS: Connectivity Graphs + Temporal + Ensemble")
print(_rule)
participants_path = os.path.join(BASE, 'participants.tsv')
participants = pd.read_csv(participants_path, sep='\t')
group_counts = participants['Group'].value_counts()
print(f"\nParticipants: {len(participants)}")
print(f"Groups: {dict(group_counts)}")
# Group codes from participants.tsv -> integer class labels, and back to names.
label_map = {'A': 0, 'C': 1, 'F': 2}
label_names = {0: 'AD', 1: 'Control', 2: 'FTD'}
# ══════════════════════════════════════════════════════════════
# STEP 2: Feature extraction
# ══════════════════════════════════════════════════════════════
print(f"\n{'=' * 70}")
print(" STEP 2: Extracting features (coherence graphs + temporal + PSD)")
print("=" * 70)
all_features = []
all_labels = []
all_subjects = []
failed = []
t0 = time.time()
for idx, row in participants.iterrows():
    sub_id = row['participant_id']
    group = row['Group']
    if group not in label_map:
        # Unknown diagnostic code: record the subject as failed instead of
        # aborting the whole run with a KeyError.
        print(f" [{idx+1:2d}/{len(participants)}] {sub_id} SKIPPED: unknown group {group!r}")
        failed.append(sub_id)
        continue
    label = label_map[group]
    eeg_file = os.path.join(BASE, sub_id, 'eeg', f'{sub_id}_task-eyesclosed_eeg.set')
    if not os.path.exists(eeg_file):
        failed.append(sub_id)
        continue
    try:
        raw = mne.io.read_raw_eeglab(eeg_file, preload=True, verbose=False)
        raw.filter(0.5, 45, verbose=False)
        # NOTE(review): rows of `data` are assumed to follow CHANNELS order;
        # the feature functions pair names with rows by position.
        data = raw.get_data()
        sfreq = raw.info['sfreq']
        feats = {}
        # 1) Spectral connectivity graph metrics per band
        for band_name, band_range in BANDS.items():
            coh_mat = compute_coherence_matrix(data, sfreq, band_range)
            feats.update(graph_metrics_from_matrix(coh_mat, band_name))
            # Distribution of off-diagonal coherences. Index by the matrix's own
            # size so recordings with fewer than N_CH channels do not fail here.
            upper = coh_mat[np.triu_indices(coh_mat.shape[0], k=1)]
            feats[f'{band_name}_coh_median'] = np.median(upper)
            feats[f'{band_name}_coh_q25'] = np.percentile(upper, 25)
            feats[f'{band_name}_coh_q75'] = np.percentile(upper, 75)
        # 2) Temporal / signal-morphology features
        feats.update(compute_temporal_cnn_features(data, sfreq))
        # 3) Band power features
        feats.update(compute_band_power_features(data, sfreq))
        # 4) Demographics / clinical covariates
        # NOTE(review): MMSE is a clinical cognitive score and is presumably
        # strongly tied to the diagnostic group. It can dominate the classifier
        # and overstate what the EEG features achieve. Consider excluding it.
        feats['age'] = row['Age']
        feats['gender'] = 1 if row['Gender'] == 'M' else 0
        feats['mmse'] = row['MMSE']
        all_features.append(feats)
        all_labels.append(label)
        all_subjects.append(sub_id)
        elapsed = time.time() - t0
        print(f" [{idx+1:2d}/{len(participants)}] {sub_id} [{label_names[label]:>7s}] "
              f"— {len(feats)} features ({elapsed:.0f}s)")
    except Exception as e:
        # Best-effort batch: a bad recording is logged and skipped.
        print(f" [{idx+1:2d}/{len(participants)}] {sub_id} FAILED: {e}")
        failed.append(sub_id)
print(f"\nExtracted: {len(all_features)} subjects | Failed: {len(failed)}")
print(f"Total time: {time.time() - t0:.0f}s")
# Convert to matrix. Features missing for a subject become 0.
X = pd.DataFrame(all_features).fillna(0)
y = np.array(all_labels)
# Replace +/-inf (e.g. from degenerate spectra) with 0 so every column is finite.
X = X.replace([np.inf, -np.inf], 0)
print(f"Feature matrix: {X.shape}")
print(f"Labels: AD={sum(y==0)}, Control={sum(y==1)}, FTD={sum(y==2)}")
# Save raw features
X.to_csv(os.path.join(OUTPUT_DIR, 'deep_features.csv'), index=False)
np.save(os.path.join(OUTPUT_DIR, 'labels.npy'), y)
# ══════════════════════════════════════════════════════════════
# STEP 3: Feature selection & scaling
# ══════════════════════════════════════════════════════════════
print(f"\n{'=' * 70}")
print(" STEP 3: Feature selection")
print("=" * 70)
# Keep the (up to) 120 features with the highest ANOVA F-score.
k_features = min(120, X.shape[1])
selector = SelectKBest(f_classif, k=k_features)
selector.fit(X, y)
X_selected = selector.transform(X)
selected_mask = selector.get_support()
selected_features = list(X.columns[selected_mask])
print(f"Selected {len(selected_features)} / {X.shape[1]} features")
# List the 25 strongest of the kept features, highest F-score first.
scores = selector.scores_[selected_mask]
top_idx = np.argsort(scores)[::-1][:25]
print("\nTop 25 most discriminative features:")
for rank, ix in enumerate(top_idx, start=1):
    print(f" {rank:2d}. {selected_features[ix]:45s} F={scores[ix]:.1f}")
# Standardise the selected features for the downstream models.
scaler = StandardScaler()
scaler.fit(X_selected)
X_scaled = scaler.transform(X_selected)
# ══════════════════════════════════════════════════════════════
# STEP 4: 3-way classification (AD vs FTD vs Control)
# ══════════════════════════════════════════════════════════════
print(f"\n{'=' * 70}")
print(" STEP 4: 3-way classification (AD vs FTD vs Control)")
print("=" * 70)
cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
gb = GradientBoostingClassifier(
    n_estimators=300, max_depth=4, learning_rate=0.05,
    subsample=0.8, random_state=42,
)
rf = RandomForestClassifier(
    n_estimators=400, max_depth=6, min_samples_leaf=2, random_state=42,
)
svm = SVC(
    kernel='rbf', C=10, gamma='scale', probability=True, random_state=42,
)
models_3c = {
    'GradientBoosting': gb,
    'RandomForest': rf,
    'SVM_RBF': svm,
}
# Also build a soft-voting ensemble of the three models.
voting = VotingClassifier(
    estimators=[('gb', gb), ('rf', rf), ('svm', svm)],
    voting='soft',
)
models_3c['VotingEnsemble'] = voting
results_3c = {}
for name, model in models_3c.items():
    # Feature selection and scaling are re-fit inside every training fold.
    # Using selector/scaler fit on all subjects would let the held-out fold
    # influence which features are kept and inflate the CV estimates.
    cv_model = Pipeline([
        ('select', SelectKBest(f_classif, k=k_features)),
        ('scale', StandardScaler()),
        ('model', model),
    ])
    y_pred = cross_val_predict(cv_model, X, y, cv=cv)
    y_prob = cross_val_predict(cv_model, X, y, cv=cv, method='predict_proba')
    acc = accuracy_score(y, y_pred)
    # One-vs-rest AUC, weighted by class support
    try:
        auc = roc_auc_score(y, y_prob, multi_class='ovr', average='weighted')
    except ValueError:
        auc = 0.0
    results_3c[name] = {'accuracy': acc, 'auc': auc, 'y_pred': y_pred, 'y_prob': y_prob}
    print(f"\n{'─' * 50}")
    print(f" {name}: Accuracy = {acc:.1%} | AUC(OvR) = {auc:.3f}")
    print(f"{'─' * 50}")
    print(classification_report(y, y_pred, target_names=['AD', 'Control', 'FTD']))
    print("Confusion matrix:")
    cm = confusion_matrix(y, y_pred)
    print(f" {'':>10s} pred_AD pred_Ctrl pred_FTD")
    for i, lbl in enumerate(['AD', 'Control', 'FTD']):
        print(f" {lbl:>10s} {cm[i,0]:7d} {cm[i,1]:9d} {cm[i,2]:8d}")
best_3c = max(results_3c, key=lambda k: results_3c[k]['accuracy'])
print(f"\n>>> Best 3-class model: {best_3c} "
      f"(Acc={results_3c[best_3c]['accuracy']:.1%}, "
      f"AUC={results_3c[best_3c]['auc']:.3f})")
# ══════════════════════════════════════════════════════════════
# STEP 5: Binary classification (AD vs non-AD)
# ══════════════════════════════════════════════════════════════
print(f"\n{'=' * 70}")
print(" STEP 5: Binary classification (AD vs non-AD)")
print("=" * 70)
y_binary = (y == 0).astype(int)  # AD=1, non-AD=0
gb_b = GradientBoostingClassifier(
    n_estimators=300, max_depth=4, learning_rate=0.05,
    subsample=0.8, random_state=42,
)
rf_b = RandomForestClassifier(
    n_estimators=400, max_depth=6, min_samples_leaf=2, random_state=42,
)
svm_b = SVC(
    kernel='rbf', C=10, gamma='scale', probability=True, random_state=42,
)
voting_b = VotingClassifier(
    estimators=[('gb', gb_b), ('rf', rf_b), ('svm', svm_b)],
    voting='soft',
)
models_bin = {
    'GradientBoosting': gb_b,
    'RandomForest': rf_b,
    'SVM_RBF': svm_b,
    'VotingEnsemble': voting_b,
}
cv_bin = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
results_bin = {}
for name, model in models_bin.items():
    # Feature selection and scaling are re-fit inside every training fold, so
    # the held-out fold never influences which features are kept.
    cv_model = Pipeline([
        ('select', SelectKBest(f_classif, k=k_features)),
        ('scale', StandardScaler()),
        ('model', model),
    ])
    y_pred = cross_val_predict(cv_model, X, y_binary, cv=cv_bin)
    y_prob = cross_val_predict(cv_model, X, y_binary, cv=cv_bin, method='predict_proba')
    acc = accuracy_score(y_binary, y_pred)
    auc = roc_auc_score(y_binary, y_prob[:, 1])
    # Sensitivity = recall of AD (class 1); specificity = recall of non-AD.
    sens = np.sum((y_pred == 1) & (y_binary == 1)) / (np.sum(y_binary == 1) + 1e-10)
    spec = np.sum((y_pred == 0) & (y_binary == 0)) / (np.sum(y_binary == 0) + 1e-10)
    results_bin[name] = {'accuracy': acc, 'auc': auc, 'sensitivity': sens, 'specificity': spec}
    print(f"\n{'─' * 50}")
    print(f" {name}: Acc={acc:.1%} AUC={auc:.3f} Sens={sens:.1%} Spec={spec:.1%}")
    print(f"{'─' * 50}")
    print(classification_report(y_binary, y_pred, target_names=['non-AD', 'AD']))
best_bin = max(results_bin, key=lambda k: results_bin[k]['auc'])
print(f"\n>>> Best binary model: {best_bin} "
      f"(AUC={results_bin[best_bin]['auc']:.3f}, "
      f"Acc={results_bin[best_bin]['accuracy']:.1%})")
# ══════════════════════════════════════════════════════════════
# STEP 6: Train final models on full data & save
# ══════════════════════════════════════════════════════════════
print(f"\n{'=' * 70}")
print(" STEP 6: Training final models & saving artifacts")
print("=" * 70)


def _soft_voting_ensemble():
    """Return a fresh soft-voting GB + RF + RBF-SVM ensemble with fixed hyperparameters."""
    members = [
        ('gb', GradientBoostingClassifier(
            n_estimators=300, max_depth=4, learning_rate=0.05,
            subsample=0.8, random_state=42)),
        ('rf', RandomForestClassifier(
            n_estimators=400, max_depth=6, min_samples_leaf=2, random_state=42)),
        ('svm', SVC(kernel='rbf', C=10, gamma='scale', probability=True, random_state=42)),
    ]
    return VotingClassifier(estimators=members, voting='soft')


def _round4(value):
    """Round to 4 decimals through string formatting, as a plain float."""
    return float(f"{value:.4f}")


# Final models are fit on all subjects, using the selected and scaled features.
final_3c = _soft_voting_ensemble().fit(X_scaled, y)
final_bin = _soft_voting_ensemble().fit(X_scaled, y_binary)

# Feature importances from the GradientBoosting member of the 3-class ensemble.
importances = final_3c.named_estimators_['gb'].feature_importances_
top_fi = np.argsort(importances)[::-1][:20]
print("\nTop 20 feature importances (GradientBoosting, 3-class):")
for rank, ix in enumerate(top_fi, start=1):
    print(f" {rank:2d}. {selected_features[ix]:45s} imp={importances[ix]:.4f}")
fi_df = pd.DataFrame({'feature': selected_features, 'importance': importances})
fi_df = fi_df.sort_values('importance', ascending=False)
fi_df.to_csv(os.path.join(OUTPUT_DIR, 'feature_importance.csv'), index=False)

# Everything needed to reproduce inference, pickled together.
results_3class_compact = {}
for model_name, metrics in results_3c.items():
    results_3class_compact[model_name] = {
        key: val for key, val in metrics.items() if key not in ('y_pred', 'y_prob')
    }
artifacts = {
    'model_3class': final_3c,
    'model_binary': final_bin,
    'scaler': scaler,
    'selector': selector,
    'feature_names': list(X.columns),
    'selected_features': selected_features,
    'label_names_3class': {0: 'AD', 1: 'Control', 2: 'FTD'},
    'label_names_binary': {0: 'non-AD', 1: 'AD'},
    'channels': CHANNELS,
    'bands': BANDS,
    'results_3class': results_3class_compact,
    'results_binary': results_bin,
}
model_path = os.path.join(OUTPUT_DIR, 'eeg_deep_analysis.pkl')
with open(model_path, 'wb') as f:
    pickle.dump(artifacts, f)
print(f"\nModel saved: {model_path} ({os.path.getsize(model_path)/1e6:.1f} MB)")

# Human-readable results summary (metrics rounded to 4 decimals).
summary = {
    'dataset': 'OpenNeuro ds004504',
    'n_subjects': len(all_features),
    'n_failed': len(failed),
    'n_features_total': X.shape[1],
    'n_features_selected': len(selected_features),
    'classification_3way': {
        model_name: {
            'accuracy': _round4(metrics['accuracy']),
            'auc_ovr': _round4(metrics['auc']),
        }
        for model_name, metrics in results_3c.items()
    },
    'classification_binary_AD_vs_nonAD': {
        model_name: {key: _round4(val) for key, val in metrics.items()}
        for model_name, metrics in results_bin.items()
    },
    'best_3class_model': best_3c,
    'best_binary_model': best_bin,
}
with open(os.path.join(OUTPUT_DIR, 'results_summary.json'), 'w') as f:
    json.dump(summary, f, indent=2)
# ══════════════════════════════════════════════════════════════
# FINAL SUMMARY
# ══════════════════════════════════════════════════════════════
print(f"\n{'=' * 70}")
print(" EEG DEEP ANALYSIS — COMPLETE")
print("=" * 70)
print(f" Subjects: {len(all_features)} ({sum(y==0)} AD, {sum(y==1)} Ctrl, {sum(y==2)} FTD)")
print(f" Features: {X.shape[1]} total -> {len(selected_features)} selected")
print(f" 3-class best: {best_3c} Acc={results_3c[best_3c]['accuracy']:.1%} AUC={results_3c[best_3c]['auc']:.3f}")
print(f" Binary best: {best_bin} AUC={results_bin[best_bin]['auc']:.3f} Acc={results_bin[best_bin]['accuracy']:.1%}")
print(f" Output dir: {OUTPUT_DIR}")
print(" Author: Satyawan Singh — Infonova Solutions")
print("=" * 70)