| |
| """ |
| EEG-Based Alzheimer's Disease Classifier |
| ========================================= |
| Dataset: OpenNeuro ds004504 (88 subjects: 36 AD, 23 FTD, 29 Control) |
| Features: Power spectral density bands + connectivity + complexity |
| Models: Gradient Boosting, Random Forest, Extra Trees and SVM (scikit-learn), compared by cross-validation |
| |
| Author: Satyawan Singh - Infonova Solutions |
| """ |
|
|
| import os |
| import json |
| import numpy as np |
| import pandas as pd |
| import warnings |
| import pickle |
| from collections import defaultdict |
|
|
# Suppress all Python warnings for the rest of the run.
warnings.filterwarnings('ignore')

# Compatibility shim: the feature functions call np.trapz; on NumPy builds
# that no longer provide it, alias it to np.trapezoid.
if not hasattr(np, 'trapz'):
    np.trapz = np.trapezoid

# Dataset root and output directory. The defaults are the original
# machine-specific paths; set EEG_AD_DATA_DIR / EEG_AD_OUTPUT_DIR to run the
# script elsewhere without editing the source.
BASE = os.environ.get(
    'EEG_AD_DATA_DIR',
    '/Users/satyawansingh/Documents/alzheimer-research-complete/data/openneuro_ad_eeg',
)
OUTPUT_DIR = os.environ.get(
    'EEG_AD_OUTPUT_DIR',
    '/Users/satyawansingh/Documents/alzheimer-research-complete/models/eeg_ad_classifier',
)
os.makedirs(OUTPUT_DIR, exist_ok=True)
|
|
| |
| |
| |
# ---- STEP 1: participant metadata -------------------------------------
print("=" * 60, "STEP 1: Loading participant metadata", "=" * 60, sep="\n")

# participants.tsv is tab-separated; the script reads its participant_id,
# Group, Age, Gender and MMSE columns.
participants = pd.read_csv(os.path.join(BASE, 'participants.tsv'), sep='\t')
print("Total participants: {}".format(len(participants)))
print("Groups: {}".format(dict(participants['Group'].value_counts())))
print(" A = Alzheimer's, F = Frontotemporal Dementia, C = Control")

# Group letter -> integer class id, and class id -> display name.
label_map = dict(A=0, C=1, F=2)
label_names = dict(enumerate(('AD', 'Control', 'FTD')))
|
|
| |
| |
| |
# ---- STEP 2: feature-extraction setup ---------------------------------
print("\n" + "=" * 60, "STEP 2: Extracting EEG features (this takes ~5 minutes)", "=" * 60, sep="\n")

import mne
from scipy import signal
from scipy.stats import kurtosis, skew

# Frequency bands as (low, high) edges in Hz; both edges are treated as
# inclusive by the feature functions.
BANDS = dict(
    delta=(0.5, 4),
    theta=(4, 8),
    alpha=(8, 13),
    beta=(13, 30),
    gamma=(30, 45),
)

# Channel names; the feature functions pair CHANNELS[i] with data row i.
CHANNELS = (
    'Fp1 Fp2 F3 F4 C3 C4 P3 P4 '
    'O1 O2 F7 F8 T3 T4 T5 T6 Fz Cz Pz'
).split()
|
|
|
|
def compute_psd_features(data, sfreq, channels=None, bands=None):
    """Compute Welch power-spectral-density features for every channel.

    Args:
        data: 2-D array indexed as ``data[channel_index]``; row ``i`` is
            treated as the recording for ``channels[i]``.
        sfreq: Sampling frequency in Hz, passed to ``scipy.signal.welch``.
        channels: Channel names in row order. Defaults to the module-level
            ``CHANNELS``. Names without a matching data row are skipped.
        bands: Mapping ``name -> (fmin, fmax)`` in Hz, both edges inclusive.
            Defaults to the module-level ``BANDS``.

    Returns:
        dict of feature name -> value. Per channel, in insertion order:
        ``<ch>_<band>_abs`` and ``<ch>_<band>_rel`` for each band, four band
        power ratios, ``<ch>_peak_alpha_freq`` and ``<ch>_spectral_entropy``.
        The ratios and the alpha peak always use the fixed 0.5-4, 4-8, 8-13
        and 13-30 Hz ranges, independent of ``bands``.
    """
    channels = CHANNELS if channels is None else channels
    bands = BANDS if bands is None else bands
    # Resolve the trapezoid integrator locally so this function does not rely
    # on the module-level ``np.trapz`` shim being installed.
    integrate = getattr(np, 'trapezoid', None) or np.trapz
    eps = 1e-10  # keeps every division finite when a band holds no power

    def band_power(freqs, psd, fmin, fmax):
        """Integrate ``psd`` over the inclusive range [fmin, fmax]."""
        mask = (freqs >= fmin) & (freqs <= fmax)
        return integrate(psd[mask], freqs[mask])

    features = {}
    # zip stops at the shorter input, so channel names beyond the available
    # data rows are skipped.
    for ch_name, ch_data in zip(channels, data):
        freqs, psd = signal.welch(ch_data, fs=sfreq, nperseg=min(2048, len(ch_data)))

        # Absolute and relative power per band.
        total_power = integrate(psd, freqs)
        for band_name, (fmin, fmax) in bands.items():
            power = band_power(freqs, psd, fmin, fmax)
            features[f'{ch_name}_{band_name}_abs'] = power
            features[f'{ch_name}_{band_name}_rel'] = power / (total_power + eps)

        # Band power ratios.
        delta_power = band_power(freqs, psd, 0.5, 4)
        theta_power = band_power(freqs, psd, 4, 8)
        alpha_power = band_power(freqs, psd, 8, 13)
        beta_power = band_power(freqs, psd, 13, 30)
        features[f'{ch_name}_theta_alpha_ratio'] = theta_power / (alpha_power + eps)
        features[f'{ch_name}_delta_alpha_ratio'] = delta_power / (alpha_power + eps)
        features[f'{ch_name}_delta_theta_ratio'] = delta_power / (theta_power + eps)
        features[f'{ch_name}_alpha_beta_ratio'] = alpha_power / (beta_power + eps)

        # Frequency of the largest PSD value inside 8-13 Hz; 0 when the
        # spectrum has no bin in that range.
        alpha_mask = (freqs >= 8) & (freqs <= 13)
        if alpha_mask.any():
            peak = freqs[alpha_mask][np.argmax(psd[alpha_mask])]
        else:
            peak = 0
        features[f'{ch_name}_peak_alpha_freq'] = peak

        # Shannon entropy (bits) of the PSD normalised to sum to 1; zero bins
        # are dropped so log2 is never evaluated at 0.
        psd_norm = psd / (psd.sum() + eps)
        psd_norm = psd_norm[psd_norm > 0]
        features[f'{ch_name}_spectral_entropy'] = -np.sum(psd_norm * np.log2(psd_norm))

    return features
|
|
|
|
def compute_temporal_features(data, channels=None):
    """Compute time-domain features for every channel.

    Args:
        data: 2-D array indexed as ``data[channel_index]``; row ``i`` is
            treated as the recording for ``channels[i]``.
        channels: Channel names in row order. Defaults to the module-level
            ``CHANNELS``. Names without a matching data row are skipped.

    Returns:
        dict of feature name -> value. Per channel, in insertion order:
        mean, std, kurtosis, skewness, rms, the three Hjorth parameters
        (activity, mobility, complexity) and the zero-crossing rate.
    """
    channels = CHANNELS if channels is None else channels
    eps = 1e-10  # keeps the Hjorth ratios finite for constant signals

    features = {}
    # zip stops at the shorter input, so channel names beyond the available
    # data rows are skipped.
    for ch_name, ch_data in zip(channels, data):
        # Basic amplitude statistics.
        features[f'{ch_name}_mean'] = np.mean(ch_data)
        features[f'{ch_name}_std'] = np.std(ch_data)
        features[f'{ch_name}_kurtosis'] = kurtosis(ch_data)
        features[f'{ch_name}_skewness'] = skew(ch_data)
        features[f'{ch_name}_rms'] = np.sqrt(np.mean(ch_data ** 2))

        # Hjorth parameters from the variance of the signal and of its first
        # and second differences.
        first_diff = np.diff(ch_data)
        second_diff = np.diff(first_diff)
        activity = np.var(ch_data)
        first_diff_var = np.var(first_diff)
        mobility = np.sqrt(first_diff_var / (activity + eps))
        complexity = np.sqrt(np.var(second_diff) / (first_diff_var + eps)) / (mobility + eps)
        features[f'{ch_name}_hjorth_activity'] = activity
        features[f'{ch_name}_hjorth_mobility'] = mobility
        features[f'{ch_name}_hjorth_complexity'] = complexity

        # Zero-crossing rate: sign changes per sample. Comparing sign bits
        # (instead of np.sign, which maps an exact 0 to its own value) counts
        # a pass through 0 once and ignores a signal that only touches 0.
        negative = np.signbit(ch_data)
        sign_changes = np.count_nonzero(negative[1:] != negative[:-1])
        # max(..., 1) avoids a ZeroDivisionError on an empty channel.
        features[f'{ch_name}_zero_crossing_rate'] = sign_changes / max(len(ch_data), 1)

    return features
|
|
|
|
def compute_connectivity_features(data, sfreq, channels=None):
    """Compute mean alpha- and theta-band coherence for every channel pair.

    Args:
        data: 2-D array indexed as ``data[channel_index]``; row ``i`` is
            treated as the recording for ``channels[i]``.
        sfreq: Sampling frequency in Hz, passed to ``scipy.signal.coherence``.
        channels: Channel names in row order. Defaults to the module-level
            ``CHANNELS``. Only the first ``min(rows, len(channels))`` channels
            are paired.

    Returns:
        dict with, for each unordered pair ``(a, b)`` (``a`` before ``b`` in
        ``channels``), ``coh_alpha_<a>_<b>`` (8-13 Hz) and
        ``coh_theta_<a>_<b>`` (4-8 Hz): the mean coherence over the band's
        frequency bins. A band with no frequency bin is omitted.
    """
    channels = CHANNELS if channels is None else channels
    n_channels = min(data.shape[0], len(channels))
    # Evaluated in this order so feature keys keep their original ordering.
    coherence_bands = (('alpha', 8, 13), ('theta', 4, 8))

    features = {}
    for i in range(n_channels):
        for j in range(i + 1, n_channels):
            freqs, coh = signal.coherence(
                data[i], data[j], fs=sfreq, nperseg=min(1024, len(data[i]))
            )
            pair = f'{channels[i]}_{channels[j]}'
            for band_name, fmin, fmax in coherence_bands:
                mask = (freqs >= fmin) & (freqs <= fmax)
                if mask.any():
                    features[f'coh_{band_name}_{pair}'] = np.mean(coh[mask])

    return features
|
|
|
|
def compute_regional_features(data, sfreq, channels=None, bands=None):
    """Compute band-power features on region-averaged signals.

    Channels are grouped into frontal, temporal, parietal, central and
    occipital regions; each region's channels are averaged sample-wise and
    the Welch PSD of that average is summarised.

    Args:
        data: 2-D array indexed as ``data[channel_index]``; row ``i`` is
            treated as the recording for ``channels[i]``.
        sfreq: Sampling frequency in Hz, passed to ``scipy.signal.welch``.
        channels: Channel names in row order. Defaults to the module-level
            ``CHANNELS``.
        bands: Mapping ``name -> (fmin, fmax)`` in Hz, both edges inclusive.
            Defaults to the module-level ``BANDS``.

    Returns:
        dict with ``region_<region>_<band>_rel`` for each band,
        ``region_<region>_theta_alpha`` (4-8 Hz power over 8-13 Hz power),
        and ``frontal_parietal_alpha_asymmetry`` computed from the F3/F4 and
        P3/P4 averages. A region with no available channel is omitted.
    """
    channels = CHANNELS if channels is None else channels
    bands = BANDS if bands is None else bands
    # Resolve the trapezoid integrator locally so this function does not rely
    # on the module-level ``np.trapz`` shim being installed.
    integrate = getattr(np, 'trapezoid', None) or np.trapz
    eps = 1e-10

    regions = {
        'frontal': ['Fp1', 'Fp2', 'F3', 'F4', 'F7', 'F8', 'Fz'],
        'temporal': ['T3', 'T4', 'T5', 'T6'],
        'parietal': ['P3', 'P4', 'Pz'],
        'central': ['C3', 'C4', 'Cz'],
        'occipital': ['O1', 'O2'],
    }

    # Only map channels that actually have a data row. Without this bound a
    # recording with fewer rows than channel names raised IndexError here,
    # whereas the per-channel feature functions skip such channels.
    n_rows = data.shape[0]
    ch_to_idx = {ch: i for i, ch in enumerate(channels) if i < n_rows}

    def band_power(freqs, psd, fmin, fmax):
        """Integrate ``psd`` over the inclusive range [fmin, fmax]."""
        mask = (freqs >= fmin) & (freqs <= fmax)
        return integrate(psd[mask], freqs[mask])

    def region_psd(names):
        """Welch PSD of the mean of the available channels in ``names``.

        Returns ``None`` when none of the channels is available.
        """
        indices = [ch_to_idx[ch] for ch in names if ch in ch_to_idx]
        if not indices:
            return None
        averaged = data[indices].mean(axis=0)
        return signal.welch(averaged, fs=sfreq, nperseg=min(2048, len(averaged)))

    features = {}
    for region_name, region_channels in regions.items():
        spectrum = region_psd(region_channels)
        if spectrum is None:
            continue
        freqs, psd = spectrum
        total_power = integrate(psd, freqs) + eps

        for band_name, (fmin, fmax) in bands.items():
            features[f'region_{region_name}_{band_name}_rel'] = (
                band_power(freqs, psd, fmin, fmax) / total_power
            )

        features[f'region_{region_name}_theta_alpha'] = (
            band_power(freqs, psd, 4, 8) / (band_power(freqs, psd, 8, 13) + eps)
        )

    # Normalised difference of 8-13 Hz power between the F3/F4 average and
    # the P3/P4 average. The frequency axis comes from the same Welch call as
    # the PSD instead of a separate third call.
    frontal = region_psd(['F3', 'F4'])
    parietal = region_psd(['P3', 'P4'])
    if frontal is not None and parietal is not None:
        freqs, f_psd = frontal
        _, p_psd = parietal
        f_alpha = band_power(freqs, f_psd, 8, 13)
        p_alpha = band_power(freqs, p_psd, 8, 13)
        features['frontal_parietal_alpha_asymmetry'] = (
            (f_alpha - p_alpha) / (f_alpha + p_alpha + eps)
        )

    return features
|
|
|
|
| |
# Per-subject accumulators. all_features / all_labels / all_subjects stay
# aligned by position; `failed` collects the ids of skipped subjects.
all_features = []
all_labels = []
all_subjects = []
failed = []

for _, row in participants.iterrows():
    sub_id = row['participant_id']
    group = row['Group']

    # Record an unexpected group code as a failed subject instead of letting
    # a KeyError abort the whole extraction run.
    if group not in label_map:
        print(f" {sub_id} FAILED: unknown group {group!r}")
        failed.append(sub_id)
        continue
    label = label_map[group]

    eeg_file = os.path.join(BASE, sub_id, 'eeg', f'{sub_id}_task-eyesclosed_eeg.set')
    if not os.path.exists(eeg_file):
        failed.append(sub_id)
        continue

    # Deliberately broad: any per-subject failure (I/O, filtering, feature
    # computation) is logged and the loop moves on to the next subject.
    try:
        raw = mne.io.read_raw_eeglab(eeg_file, preload=True, verbose=False)

        # Band-pass to 0.5-45 Hz, the range covered by BANDS.
        raw.filter(0.5, 45, verbose=False)

        data = raw.get_data()
        sfreq = raw.info['sfreq']

        # Merge all feature families into one flat dict for this subject.
        feats = {}
        feats.update(compute_psd_features(data, sfreq))
        feats.update(compute_temporal_features(data))
        feats.update(compute_connectivity_features(data, sfreq))
        feats.update(compute_regional_features(data, sfreq))

        # Participant metadata appended as extra features.
        # NOTE(review): MMSE is a clinical cognitive score rather than an EEG
        # measure and is closely tied to diagnosis; using it as a feature may
        # inflate accuracy for an "EEG-based" classifier - confirm intended.
        feats['age'] = row['Age']
        feats['gender'] = 1 if row['Gender'] == 'M' else 0
        feats['mmse'] = row['MMSE']

        all_features.append(feats)
        all_labels.append(label)
        all_subjects.append(sub_id)

        # ASCII arrow replaces a mis-encoded character in the original message.
        print(f" {sub_id} [{label_names[label]}] -> {len(feats)} features extracted")

    except Exception as e:
        print(f" {sub_id} FAILED: {e}")
        failed.append(sub_id)

print(f"\nExtracted: {len(all_features)} subjects")
print(f"Failed: {len(failed)} subjects")

# Stop with a clear message rather than failing later inside scikit-learn
# on an empty feature matrix.
if not all_features:
    raise RuntimeError(
        f"No EEG features were extracted; check that the dataset exists under {BASE!r}"
    )

# One row per subject; features missing for a subject become NaN and are
# then replaced by 0.
X = pd.DataFrame(all_features)
y = np.array(all_labels)
X = X.fillna(0)

print(f"Feature matrix: {X.shape}")
print(f"Labels: AD={sum(y==0)}, Control={sum(y==1)}, FTD={sum(y==2)}")

# Persist the raw feature table and labels alongside the model artifacts.
X.to_csv(os.path.join(OUTPUT_DIR, 'eeg_features.csv'), index=False)
np.save(os.path.join(OUTPUT_DIR, 'eeg_labels.npy'), y)
|
|
| |
| |
| |
print(f"\n{'=' * 60}")
print("STEP 3: Training classifiers")
print("=" * 60)

from sklearn.base import clone
from sklearn.model_selection import StratifiedKFold, cross_val_predict
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier, ExtraTreesClassifier
from sklearn.svm import SVC
from sklearn.metrics import classification_report, confusion_matrix, roc_auc_score
from sklearn.feature_selection import SelectKBest, f_classif


def _cv_pipeline(model):
    """Return feature selection -> scaling -> a fresh copy of *model*.

    Cross-validating this pipeline re-fits the label-dependent selector and
    the scaler on each training fold only, so held-out subjects never
    influence which features are kept or how they are scaled.
    """
    return make_pipeline(
        SelectKBest(f_classif, k=min(100, X.shape[1])),
        StandardScaler(),
        clone(model),
    )


# Selector and scaler fit on ALL subjects. They are used only for the feature
# ranking printed below and for the final models / saved artifacts in step 4,
# never for the cross-validated estimates (those use _cv_pipeline).
selector = SelectKBest(f_classif, k=min(100, X.shape[1]))
X_selected = selector.fit_transform(X, y)

selected_mask = selector.get_support()
selected_features = X.columns[selected_mask].tolist()
print(f"Selected {len(selected_features)} features from {X.shape[1]}")

# Rank the selected features by their ANOVA F-score, highest first.
scores = selector.scores_[selected_mask]
top_idx = np.argsort(scores)[::-1][:20]
print("\nTop 20 discriminative features:")
for i, idx in enumerate(top_idx):
    print(f" {i+1:2d}. {selected_features[idx]:45s} F={scores[idx]:.1f}")

scaler = StandardScaler()
X_scaled = scaler.fit_transform(X_selected)

# ---- 3-class evaluation ------------------------------------------------
print(f"\n--- 3-class: AD vs Control vs FTD ---")

cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)

models = {
    'GradientBoosting': GradientBoostingClassifier(
        n_estimators=200, max_depth=4, learning_rate=0.05, random_state=42,
    ),
    'RandomForest': RandomForestClassifier(
        n_estimators=300, max_depth=6, random_state=42,
    ),
    'ExtraTrees': ExtraTreesClassifier(
        n_estimators=300, max_depth=6, random_state=42,
    ),
    'SVM': SVC(kernel='rbf', C=10, gamma='scale', probability=True, random_state=42),
}

best_model_name = None
best_accuracy = 0

for name, model in models.items():
    # Evaluate on the raw feature table; selection and scaling happen inside
    # each fold through the pipeline.
    y_pred = cross_val_predict(_cv_pipeline(model), X, y, cv=cv)
    acc = np.mean(y_pred == y)
    print(f"\n{name}: Accuracy = {acc:.1%}")
    print(classification_report(y, y_pred, target_names=['AD', 'Control', 'FTD']))

    if acc > best_accuracy:
        best_accuracy = acc
        best_model_name = name

print(f"\nBest 3-class model: {best_model_name} ({best_accuracy:.1%})")

# ---- Binary evaluation: AD (0) vs Control (1) ---------------------------
print(f"\n--- Binary: AD vs Control ---")

binary_mask = y != 2
# X_binary (selected + scaled on all subjects) is kept for the final binary
# model in step 4; the cross-validation below uses the raw rows instead.
X_binary = X_scaled[binary_mask]
y_binary = y[binary_mask]
X_binary_raw = X.loc[binary_mask]

# AD is class 0, so flip the labels to make AD the positive class for AUC.
y_binary_pos = 1 - y_binary

cv_binary = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)

for name, model in models.items():
    pipeline = _cv_pipeline(model)
    y_pred = cross_val_predict(pipeline, X_binary_raw, y_binary, cv=cv_binary)
    y_prob = cross_val_predict(pipeline, X_binary_raw, y_binary, cv=cv_binary, method='predict_proba')

    acc = np.mean(y_pred == y_binary)
    # Column 0 of predict_proba is the probability of class 0 (AD).
    auc = roc_auc_score(y_binary_pos, y_prob[:, 0])
    print(f"\n{name}: Accuracy = {acc:.1%}, AUC = {auc:.3f}")
    print(classification_report(y_binary, y_pred, target_names=['AD', 'Control']))
|
| |
| |
| |
print(f"\n{'=' * 60}")
print("STEP 4: Training final models and saving")
print("=" * 60)

# The saved models are always gradient-boosted trees fit on all subjects,
# regardless of which model scored best in cross-validation.
gb_params = dict(n_estimators=200, max_depth=4, learning_rate=0.05, random_state=42)

final_3class = GradientBoostingClassifier(**gb_params)
final_3class.fit(X_scaled, y)

final_binary = GradientBoostingClassifier(**gb_params)
final_binary.fit(X_binary, y_binary)

# Everything needed to reproduce preprocessing and prediction at inference
# time: raw features -> selector -> scaler -> model.
artifacts = {
    'model_3class': final_3class,
    'model_binary': final_binary,
    'scaler': scaler,
    'selector': selector,
    'feature_names': list(X.columns),
    'selected_features': selected_features,
    'label_names_3class': {0: 'AD', 1: 'Control', 2: 'FTD'},
    'label_names_binary': {0: 'AD', 1: 'Control'},
    'channels': CHANNELS,
    'bands': BANDS,
    'best_cv_accuracy': best_accuracy,
    # Names the model that achieved best_cv_accuracy; it is not necessarily
    # the GradientBoosting model stored above.
    'best_cv_model': best_model_name,
}

model_path = os.path.join(OUTPUT_DIR, 'eeg_ad_classifier.pkl')
with open(model_path, 'wb') as f:
    pickle.dump(artifacts, f)

print(f"\nSaved: {model_path} ({os.path.getsize(model_path)/1e6:.1f} MB)")

# Importances are indexed like selected_features because the model was fit
# on the selector's output columns.
importances = final_3class.feature_importances_
top_features = np.argsort(importances)[::-1][:15]
# The heading previously said "XGBoost", but the model is scikit-learn's
# GradientBoostingClassifier.
print("\nTop 15 features (GradientBoosting importance):")
for i, idx in enumerate(top_features):
    print(f" {i+1:2d}. {selected_features[idx]:45s} importance={importances[idx]:.4f}")

# Final summary. ASCII separators replace mis-encoded characters that were
# present in the original messages.
print(f"\n{'=' * 60}")
print("EEG AD CLASSIFIER - TRAINING COMPLETE")
print("=" * 60)
print(f" Dataset: OpenNeuro ds004504")
print(f" Subjects: {len(all_features)} ({sum(y==0)} AD, {sum(y==1)} Control, {sum(y==2)} FTD)")
print(f" Features: {X.shape[1]} total -> {len(selected_features)} selected")
print(f" Best 3-class CV accuracy: {best_accuracy:.1%}")
print(f" Model saved: {model_path}")
print(f" Author: Satyawan Singh - Infonova Solutions")
|
|