#!/usr/bin/env python3
"""
EEG-Based Alzheimer's Disease Classifier
=========================================
Dataset:  OpenNeuro ds004504 (88 subjects: 36 AD, 23 FTD, 29 Control)
Features: Power spectral density bands + connectivity + complexity
Models:   Gradient Boosting, Random Forest, Extra Trees and RBF-SVM, compared
          with stratified 5-fold cross-validation; the best-scoring model type
          is refit on all subjects and saved.

Feature selection and scaling are refit inside every cross-validation fold, so
the reported CV scores do not see the held-out subjects.

Paths default to the author's machine; override them with the environment
variables ``EEG_AD_DATA_DIR`` and ``EEG_AD_OUTPUT_DIR``.

Author: Satyawan Singh — Infonova Solutions
"""
import os
import json
import numpy as np
import pandas as pd
import warnings
import pickle
from collections import defaultdict

from scipy import signal
from scipy.stats import kurtosis, skew

warnings.filterwarnings('ignore')

# NumPy 2.0 renamed np.trapz to np.trapezoid. Resolve the integrator once
# instead of monkey-patching the numpy module.
try:
    _trapezoid = np.trapezoid
except AttributeError:  # NumPy < 2.0
    _trapezoid = np.trapz

# ── Paths ──
BASE = os.environ.get(
    'EEG_AD_DATA_DIR',
    '/Users/satyawansingh/Documents/alzheimer-research-complete/data/openneuro_ad_eeg',
)
OUTPUT_DIR = os.environ.get(
    'EEG_AD_OUTPUT_DIR',
    '/Users/satyawansingh/Documents/alzheimer-research-complete/models/eeg_ad_classifier',
)

# Label mapping
label_map = {'A': 0, 'C': 1, 'F': 2}  # AD=0, Control=1, FTD=2
label_names = {0: 'AD', 1: 'Control', 2: 'FTD'}

# Frequency bands in Hz (both edges inclusive, as in every mask below)
BANDS = {
    'delta': (0.5, 4),
    'theta': (4, 8),
    'alpha': (8, 13),
    'beta': (13, 30),
    'gamma': (30, 45),
}

# Row i of a recording's data matrix is treated as CHANNELS[i].
CHANNELS = ['Fp1', 'Fp2', 'F3', 'F4', 'C3', 'C4', 'P3', 'P4', 'O1', 'O2',
            'F7', 'F8', 'T3', 'T4', 'T5', 'T6', 'Fz', 'Cz', 'Pz']

EPS = 1e-10               # keeps every ratio finite when a denominator is 0
N_SELECTED_FEATURES = 100
RANDOM_STATE = 42

# MMSE is a clinical cognitive score, not an EEG measurement, and is typically
# strongly correlated with the diagnostic group, so it can dominate an
# "EEG-based" classifier. Kept on by default so saved artifacts keep the same
# feature set; set to False for an EEG + demographics-only model.
INCLUDE_MMSE = True


# ══════════════════════════════════════════════════════════════
# Feature extraction
# ══════════════════════════════════════════════════════════════

def _band_power(psd, freqs, fmin, fmax):
    """Integrate ``psd`` over the closed band [fmin, fmax] Hz (trapezoid rule)."""
    mask = (freqs >= fmin) & (freqs <= fmax)
    return _trapezoid(psd[mask], freqs[mask])


def compute_psd_features(data, sfreq):
    """Compute power-spectral-density features per channel.

    Args:
        data: 2-D array with one row per channel in ``CHANNELS`` order. Rows
            beyond ``len(CHANNELS)`` are ignored; missing trailing channels
            are skipped.
        sfreq: Sampling frequency in Hz.

    Returns:
        Dict keyed ``'<channel>_<feature>'`` holding absolute/relative band
        power, theta/alpha, delta/alpha, delta/theta and alpha/beta ratios,
        peak alpha frequency (0 if the alpha band has no bins) and spectral
        entropy in bits.
    """
    features = {}
    for ch_name, ch_data in zip(CHANNELS, data):
        freqs, psd = signal.welch(ch_data, fs=sfreq, nperseg=min(2048, len(ch_data)))
        total_power = _trapezoid(psd, freqs)

        band_power = {}
        for band_name, (fmin, fmax) in BANDS.items():
            power = _band_power(psd, freqs, fmin, fmax)
            band_power[band_name] = power
            features[f'{ch_name}_{band_name}_abs'] = power
            features[f'{ch_name}_{band_name}_rel'] = power / (total_power + EPS)

        # Spectral ratios (slowing markers), reusing the band powers above
        delta, theta = band_power['delta'], band_power['theta']
        alpha, beta = band_power['alpha'], band_power['beta']
        features[f'{ch_name}_theta_alpha_ratio'] = theta / (alpha + EPS)
        features[f'{ch_name}_delta_alpha_ratio'] = delta / (alpha + EPS)
        features[f'{ch_name}_delta_theta_ratio'] = delta / (theta + EPS)
        features[f'{ch_name}_alpha_beta_ratio'] = alpha / (beta + EPS)

        # Peak alpha frequency
        alpha_lo, alpha_hi = BANDS['alpha']
        alpha_mask = (freqs >= alpha_lo) & (freqs <= alpha_hi)
        alpha_psd = psd[alpha_mask]
        features[f'{ch_name}_peak_alpha_freq'] = (
            freqs[alpha_mask][np.argmax(alpha_psd)] if alpha_psd.size else 0
        )

        # Spectral entropy over the normalised PSD (zero bins excluded)
        psd_norm = psd / (psd.sum() + EPS)
        psd_norm = psd_norm[psd_norm > 0]
        features[f'{ch_name}_spectral_entropy'] = -np.sum(psd_norm * np.log2(psd_norm))
    return features


def compute_temporal_features(data):
    """Compute time-domain features per channel.

    Produces mean, std, kurtosis, skewness, RMS, the three Hjorth parameters
    (activity, mobility, complexity) and the zero-crossing rate for each row
    of ``data`` that has a name in ``CHANNELS``.
    """
    features = {}
    for ch_name, ch_data in zip(CHANNELS, data):
        features[f'{ch_name}_mean'] = np.mean(ch_data)
        features[f'{ch_name}_std'] = np.std(ch_data)
        features[f'{ch_name}_kurtosis'] = kurtosis(ch_data)
        features[f'{ch_name}_skewness'] = skew(ch_data)
        features[f'{ch_name}_rms'] = np.sqrt(np.mean(ch_data ** 2))

        # Hjorth parameters
        diff1 = np.diff(ch_data)
        diff2 = np.diff(diff1)
        activity = np.var(ch_data)
        var_diff1 = np.var(diff1)
        mobility = np.sqrt(var_diff1 / (activity + EPS))
        complexity = np.sqrt(np.var(diff2) / (var_diff1 + EPS)) / (mobility + EPS)
        features[f'{ch_name}_hjorth_activity'] = activity
        features[f'{ch_name}_hjorth_mobility'] = mobility
        features[f'{ch_name}_hjorth_complexity'] = complexity

        # Fraction of samples at which the sign changes
        features[f'{ch_name}_zero_crossing_rate'] = (
            np.count_nonzero(np.diff(np.sign(ch_data))) / len(ch_data)
        )
    return features


def compute_connectivity_features(data, sfreq):
    """Mean magnitude-squared coherence for every channel pair.

    For each pair (i < j) emits ``coh_alpha_<ch_i>_<ch_j>`` and then
    ``coh_theta_<ch_i>_<ch_j>``. A key is emitted only when its band contains
    at least one frequency bin.
    """
    features = {}
    n_channels = min(data.shape[0], len(CHANNELS))
    for i in range(n_channels):
        for j in range(i + 1, n_channels):
            freqs, coh = signal.coherence(
                data[i], data[j], fs=sfreq, nperseg=min(1024, len(data[i])),
            )
            for band_name in ('alpha', 'theta'):
                fmin, fmax = BANDS[band_name]
                mask = (freqs >= fmin) & (freqs <= fmax)
                if mask.any():
                    key = f'coh_{band_name}_{CHANNELS[i]}_{CHANNELS[j]}'
                    features[key] = np.mean(coh[mask])
    return features


def compute_regional_features(data, sfreq):
    """Region-averaged relative band powers plus a frontal/parietal alpha index.

    Channels of each region are averaged in the time domain before the Welch
    PSD. Regions with no recorded channels are skipped. The asymmetry index
    ``(F - P) / (F + P)`` uses F3/F4 versus P3/P4 alpha power.
    """
    regions = {
        'frontal': ['Fp1', 'Fp2', 'F3', 'F4', 'F7', 'F8', 'Fz'],
        'temporal': ['T3', 'T4', 'T5', 'T6'],
        'parietal': ['P3', 'P4', 'Pz'],
        'central': ['C3', 'C4', 'Cz'],
        'occipital': ['O1', 'O2'],
    }
    features = {}
    # Map only channels that actually have a row in ``data``. This mirrors the
    # other feature functions and avoids IndexError on shorter recordings.
    ch_to_idx = {ch: i for i, ch in enumerate(CHANNELS[:data.shape[0]])}

    for region_name, region_channels in regions.items():
        indices = [ch_to_idx[ch] for ch in region_channels if ch in ch_to_idx]
        if not indices:
            continue
        region_data = data[indices].mean(axis=0)
        freqs, psd = signal.welch(region_data, fs=sfreq, nperseg=min(2048, len(region_data)))
        total_power = _trapezoid(psd, freqs) + EPS
        for band_name, (fmin, fmax) in BANDS.items():
            features[f'region_{region_name}_{band_name}_rel'] = (
                _band_power(psd, freqs, fmin, fmax) / total_power
            )
        features[f'region_{region_name}_theta_alpha'] = (
            _band_power(psd, freqs, *BANDS['theta'])
            / (_band_power(psd, freqs, *BANDS['alpha']) + EPS)
        )

    # Frontal vs parietal alpha asymmetry
    frontal_idx = [ch_to_idx[ch] for ch in ['F3', 'F4'] if ch in ch_to_idx]
    parietal_idx = [ch_to_idx[ch] for ch in ['P3', 'P4'] if ch in ch_to_idx]
    if frontal_idx and parietal_idx:
        f_data = data[frontal_idx].mean(axis=0)
        p_data = data[parietal_idx].mean(axis=0)
        nperseg = min(2048, len(f_data))
        # One Welch call per signal; both share the same frequency grid.
        freqs, f_psd = signal.welch(f_data, fs=sfreq, nperseg=nperseg)
        _, p_psd = signal.welch(p_data, fs=sfreq, nperseg=nperseg)
        f_alpha = _band_power(f_psd, freqs, *BANDS['alpha'])
        p_alpha = _band_power(p_psd, freqs, *BANDS['alpha'])
        features['frontal_parietal_alpha_asymmetry'] = (
            (f_alpha - p_alpha) / (f_alpha + p_alpha + EPS)
        )
    return features


def extract_subject_features(data, sfreq):
    """Merge every EEG feature group for one recording (no demographics)."""
    feats = {}
    feats.update(compute_psd_features(data, sfreq))
    feats.update(compute_temporal_features(data))
    feats.update(compute_connectivity_features(data, sfreq))
    feats.update(compute_regional_features(data, sfreq))
    return feats


# ══════════════════════════════════════════════════════════════
# Pipeline steps
# ══════════════════════════════════════════════════════════════

def load_participants(base=BASE):
    """STEP 1: read ``participants.tsv`` from the dataset root and report groups."""
    print("=" * 60)
    print("STEP 1: Loading participant metadata")
    print("=" * 60)
    participants = pd.read_csv(os.path.join(base, 'participants.tsv'), sep='\t')
    print(f"Total participants: {len(participants)}")
    print(f"Groups: {dict(participants['Group'].value_counts())}")
    print("  A = Alzheimer's, F = Frontotemporal Dementia, C = Control")
    return participants


def extract_dataset(participants, base=BASE):
    """STEP 2: extract features for every participant's eyes-closed recording.

    Returns:
        ``(X, y, subjects, failed)``: a feature DataFrame with NaN filled by 0,
        the label array, the processed subject ids, and the ids that were
        skipped (unknown group, missing file, or read/feature error).
    """
    import mne  # heavy dependency, only needed for reading raw EEG

    print(f"\n{'=' * 60}")
    print("STEP 2: Extracting EEG features (this takes ~5 minutes)")
    print("=" * 60)

    all_features, all_labels, all_subjects, failed = [], [], [], []
    for _, row in participants.iterrows():
        sub_id = row['participant_id']
        label = label_map.get(row['Group'])
        if label is None:
            print(f"  {sub_id} SKIPPED: unknown group {row['Group']!r}")
            failed.append(sub_id)
            continue

        eeg_file = os.path.join(base, sub_id, 'eeg', f'{sub_id}_task-eyesclosed_eeg.set')
        if not os.path.exists(eeg_file):
            print(f"  {sub_id} MISSING: {eeg_file}")
            failed.append(sub_id)
            continue

        try:
            raw = mne.io.read_raw_eeglab(eeg_file, preload=True, verbose=False)
            # Feature functions map rows to CHANNELS by position, so enforce
            # that order by name instead of trusting the file's ordering.
            if set(CHANNELS).issubset(raw.ch_names):
                raw.reorder_channels(CHANNELS)
            else:
                print(f"  {sub_id} WARNING: not all expected channels present; "
                      f"using file channel order")
            raw.filter(0.5, 45, verbose=False)  # band-pass 0.5-45 Hz
            data = raw.get_data()
            sfreq = raw.info['sfreq']

            feats = extract_subject_features(data, sfreq)
            feats['age'] = row['Age']
            feats['gender'] = 1 if row['Gender'] == 'M' else 0
            if INCLUDE_MMSE:
                feats['mmse'] = row['MMSE']
        except Exception as e:  # one bad recording must not abort the whole run
            print(f"  {sub_id} FAILED: {e}")
            failed.append(sub_id)
            continue

        all_features.append(feats)
        all_labels.append(label)
        all_subjects.append(sub_id)
        print(f"  {sub_id} [{label_names[label]}] — {len(feats)} features extracted")

    print(f"\nExtracted: {len(all_features)} subjects")
    print(f"Failed: {len(failed)} subjects")

    X = pd.DataFrame(all_features).fillna(0)
    y = np.array(all_labels)
    print(f"Feature matrix: {X.shape}")
    print(f"Labels: AD={sum(y == 0)}, Control={sum(y == 1)}, FTD={sum(y == 2)}")
    return X, y, all_subjects, failed


def build_models():
    """Return the candidate classifiers keyed by display name (unfitted)."""
    from sklearn.ensemble import (
        ExtraTreesClassifier, GradientBoostingClassifier, RandomForestClassifier,
    )
    from sklearn.svm import SVC

    return {
        'GradientBoosting': GradientBoostingClassifier(
            n_estimators=200, max_depth=4, learning_rate=0.05, random_state=RANDOM_STATE,
        ),
        'RandomForest': RandomForestClassifier(
            n_estimators=300, max_depth=6, random_state=RANDOM_STATE,
        ),
        'ExtraTrees': ExtraTreesClassifier(
            n_estimators=300, max_depth=6, random_state=RANDOM_STATE,
        ),
        'SVM': SVC(kernel='rbf', C=10, gamma='scale', probability=True,
                   random_state=RANDOM_STATE),
    }


def _cv_pipeline(model, n_features):
    """Selection -> scaling -> model, refit inside every CV fold (no leakage)."""
    from sklearn.feature_selection import SelectKBest, f_classif
    from sklearn.pipeline import make_pipeline
    from sklearn.preprocessing import StandardScaler

    return make_pipeline(
        SelectKBest(f_classif, k=min(N_SELECTED_FEATURES, n_features)),
        StandardScaler(),
        model,
    )


def select_features(X, y):
    """Fit the deployment selector and scaler on all subjects.

    Returns ``(selector, scaler, X_scaled, selected_features)``. These objects
    are for the final saved model and the descriptive ranking only. Cross-
    validation refits its own copies per fold.
    """
    from sklearn.feature_selection import SelectKBest, f_classif
    from sklearn.preprocessing import StandardScaler

    selector = SelectKBest(f_classif, k=min(N_SELECTED_FEATURES, X.shape[1]))
    X_selected = selector.fit_transform(X, y)
    selected_mask = selector.get_support()
    selected_features = X.columns[selected_mask].tolist()
    print(f"Selected {len(selected_features)} features from {X.shape[1]}")

    scores = selector.scores_[selected_mask]
    top_idx = np.argsort(scores)[::-1][:20]
    print("\nTop 20 discriminative features:")
    for rank, idx in enumerate(top_idx, start=1):
        print(f"  {rank:2d}. {selected_features[idx]:45s} F={scores[idx]:.1f}")

    scaler = StandardScaler()
    X_scaled = scaler.fit_transform(X_selected)
    return selector, scaler, X_scaled, selected_features


def evaluate_three_class(X, y, models):
    """Cross-validate AD vs Control vs FTD; return ``(best_name, best_accuracy)``."""
    from sklearn.metrics import classification_report
    from sklearn.model_selection import StratifiedKFold, cross_val_predict

    print("\n--- 3-class: AD vs Control vs FTD ---")
    cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=RANDOM_STATE)
    best_name, best_accuracy = None, -1.0
    for name, model in models.items():
        y_pred = cross_val_predict(_cv_pipeline(model, X.shape[1]), X, y, cv=cv)
        acc = np.mean(y_pred == y)
        print(f"\n{name}: Accuracy = {acc:.1%}")
        print(classification_report(y, y_pred, labels=[0, 1, 2],
                                    target_names=['AD', 'Control', 'FTD'],
                                    zero_division=0))
        if acc > best_accuracy:
            best_name, best_accuracy = name, acc
    print(f"\nBest 3-class model: {best_name} ({best_accuracy:.1%})")
    return best_name, best_accuracy


def evaluate_binary(X, y, models):
    """Cross-validate AD vs Control (FTD dropped); return the best name by AUC."""
    from sklearn.metrics import classification_report, roc_auc_score
    from sklearn.model_selection import StratifiedKFold, cross_val_predict

    print("\n--- Binary: AD vs Control ---")
    binary_mask = y != 2
    X_binary = X.loc[binary_mask]
    y_binary = y[binary_mask]
    y_binary_pos = (y_binary == 0).astype(int)  # AD is the positive class for AUC

    cv_binary = StratifiedKFold(n_splits=5, shuffle=True, random_state=RANDOM_STATE)
    best_name, best_auc = None, -1.0
    for name, model in models.items():
        pipe = _cv_pipeline(model, X.shape[1])
        y_pred = cross_val_predict(pipe, X_binary, y_binary, cv=cv_binary)
        y_prob = cross_val_predict(pipe, X_binary, y_binary, cv=cv_binary,
                                   method='predict_proba')
        acc = np.mean(y_pred == y_binary)
        auc = roc_auc_score(y_binary_pos, y_prob[:, 0])  # column 0 = P(AD)
        print(f"\n{name}: Accuracy = {acc:.1%}, AUC = {auc:.3f}")
        print(classification_report(y_binary, y_pred, labels=[0, 1],
                                    target_names=['AD', 'Control'], zero_division=0))
        if auc > best_auc:
            best_name, best_auc = name, auc
    print(f"\nBest binary model: {best_name} (AUC {best_auc:.3f})")
    return best_name


def train_and_save(X, y, selector, scaler, X_scaled, selected_features, models,
                   best_name, best_accuracy, best_binary_name):
    """STEP 4: refit the best model types on all subjects and pickle artifacts.

    Artifacts keep the original layout. The models expect input that has gone
    through ``selector.transform`` and then ``scaler.transform``.
    """
    from sklearn.base import clone

    print(f"\n{'=' * 60}")
    print("STEP 4: Training final models and saving")
    print("=" * 60)

    final_3class = clone(models[best_name]).fit(X_scaled, y)
    binary_mask = y != 2
    final_binary = clone(models[best_binary_name]).fit(X_scaled[binary_mask], y[binary_mask])

    artifacts = {
        'model_3class': final_3class,
        'model_binary': final_binary,
        'model_3class_type': best_name,
        'model_binary_type': best_binary_name,
        'scaler': scaler,
        'selector': selector,
        'feature_names': list(X.columns),
        'selected_features': selected_features,
        'label_names_3class': {0: 'AD', 1: 'Control', 2: 'FTD'},
        'label_names_binary': {0: 'AD', 1: 'Control'},
        'channels': CHANNELS,
        'bands': BANDS,
        'include_mmse': INCLUDE_MMSE,
        'best_cv_accuracy': best_accuracy,
    }
    model_path = os.path.join(OUTPUT_DIR, 'eeg_ad_classifier.pkl')
    with open(model_path, 'wb') as f:
        pickle.dump(artifacts, f)
    print(f"\nSaved: {model_path} ({os.path.getsize(model_path) / 1e6:.1f} MB)")

    # SVC has no feature_importances_, so the report is optional.
    importances = getattr(final_3class, 'feature_importances_', None)
    if importances is None:
        print(f"\n{best_name} does not expose feature importances; skipping report.")
    else:
        print(f"\nTop 15 features ({best_name} importance):")
        for rank, idx in enumerate(np.argsort(importances)[::-1][:15], start=1):
            print(f"  {rank:2d}. {selected_features[idx]:45s} importance={importances[idx]:.4f}")
    return model_path


def main():
    """Run the full pipeline: load, extract, evaluate, train, save."""
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    participants = load_participants(BASE)
    X, y, _subjects, _failed = extract_dataset(participants, BASE)
    if len(y) == 0:
        raise SystemExit(f"No subjects could be processed from {BASE}")

    X.to_csv(os.path.join(OUTPUT_DIR, 'eeg_features.csv'), index=False)
    np.save(os.path.join(OUTPUT_DIR, 'eeg_labels.npy'), y)

    print(f"\n{'=' * 60}")
    print("STEP 3: Training classifiers")
    print("=" * 60)
    models = build_models()
    selector, scaler, X_scaled, selected_features = select_features(X, y)
    best_name, best_accuracy = evaluate_three_class(X, y, models)
    best_binary_name = evaluate_binary(X, y, models)

    model_path = train_and_save(X, y, selector, scaler, X_scaled, selected_features,
                                models, best_name, best_accuracy, best_binary_name)

    print(f"\n{'=' * 60}")
    print("EEG AD CLASSIFIER — TRAINING COMPLETE")
    print("=" * 60)
    print("  Dataset: OpenNeuro ds004504")
    print(f"  Subjects: {len(y)} ({sum(y == 0)} AD, {sum(y == 1)} Control, {sum(y == 2)} FTD)")
    print(f"  Features: {X.shape[1]} total → {len(selected_features)} selected")
    print(f"  Best 3-class CV accuracy: {best_accuracy:.1%} ({best_name})")
    print(f"  Model saved: {model_path}")
    print("  Author: Satyawan Singh — Infonova Solutions")


if __name__ == '__main__':
    main()