alzheimer-research-complete / train_eeg_ad_classifier.py
Satyawan1's picture
Upload train_eeg_ad_classifier.py with huggingface_hub
425c0e6 verified
#!/usr/bin/env python3
"""
EEG-Based Alzheimer's Disease Classifier
=========================================
Dataset: OpenNeuro ds004504 (88 subjects: 36 AD, 23 FTD, 29 Control)
Features: Power spectral density bands + connectivity + complexity
Models: Gradient Boosting, Random Forest, Extra Trees, SVM (scikit-learn)
Author: Satyawan Singh — Infonova Solutions
"""
import os
import json
import numpy as np
import pandas as pd
import warnings
import pickle
from collections import defaultdict
warnings.filterwarnings('ignore')
# NumPy 2.0 compat: np.trapz → np.trapezoid
# Alias the old name only when it is missing, so the rest of the script can
# keep calling np.trapz on either NumPy line.
if not hasattr(np, 'trapz'):
    np.trapz = np.trapezoid
# ── Paths ──
# The defaults are the original author's locations. Set EEG_AD_DATA_DIR /
# EEG_AD_OUTPUT_DIR to run the script on another machine without editing it.
BASE = os.environ.get(
    'EEG_AD_DATA_DIR',
    '/Users/satyawansingh/Documents/alzheimer-research-complete/data/openneuro_ad_eeg',
)
OUTPUT_DIR = os.environ.get(
    'EEG_AD_OUTPUT_DIR',
    '/Users/satyawansingh/Documents/alzheimer-research-complete/models/eeg_ad_classifier',
)
# Created eagerly so that every later save step can assume the directory exists.
os.makedirs(OUTPUT_DIR, exist_ok=True)
# ══════════════════════════════════════════════════════════════
# STEP 1: Load participants and labels
# ══════════════════════════════════════════════════════════════
print("=" * 60)
print("STEP 1: Loading participant metadata")
print("=" * 60)
# participants.tsv is tab-separated and carries one row per subject.
participants_path = os.path.join(BASE, 'participants.tsv')
participants = pd.read_csv(participants_path, sep='\t')
group_counts = dict(participants['Group'].value_counts())
print(f"Total participants: {len(participants)}")
print(f"Groups: {group_counts}")
print(f" A = Alzheimer's, F = Frontotemporal Dementia, C = Control")
# Group code -> integer class id: AD=0, Control=1, FTD=2
label_map = {code: class_id for class_id, code in enumerate('ACF')}
# Integer class id -> display name
label_names = dict(enumerate(['AD', 'Control', 'FTD']))
# ══════════════════════════════════════════════════════════════
# STEP 2: Extract EEG features
# ══════════════════════════════════════════════════════════════
print(f"\n{'=' * 60}")
print("STEP 2: Extracting EEG features (this takes ~5 minutes)")
print("=" * 60)
import mne
from scipy import signal
from scipy.stats import kurtosis, skew
# Frequency bands: name -> (lower edge, upper edge) in Hz
BANDS = dict(
    delta=(0.5, 4),
    theta=(4, 8),
    alpha=(8, 13),
    beta=(13, 30),
    gamma=(30, 45),
)
# Channel names; the feature functions pair them with data rows by position.
CHANNELS = (
    'Fp1 Fp2 F3 F4 C3 C4 P3 P4 '
    'O1 O2 F7 F8 T3 T4 T5 T6 Fz Cz Pz'
).split()
def compute_psd_features(data, sfreq, channels=None, bands=None):
    """Compute power spectral density features per channel.

    Rows of ``data`` are paired positionally with ``channels``; pairing stops
    at whichever of the two is shorter.

    Args:
        data: 2-D array-like with one row of samples per channel.
        sfreq: Sampling frequency, forwarded to ``scipy.signal.welch`` as ``fs``.
        channels: Channel names in row order. Defaults to the module-level
            ``CHANNELS``.
        bands: Mapping ``band_name -> (fmin, fmax)`` used for the ``_abs`` and
            ``_rel`` band powers. Defaults to the module-level ``BANDS``.

    Returns:
        dict mapping ``'<channel>_<feature>'`` to a number. Per channel, in
        insertion order: absolute and relative power for every band, four
        spectral ratios, the peak alpha frequency and the spectral entropy.

    Note:
        The ratio and peak-alpha features always use the fixed ranges
        0.5-4, 4-8, 8-13 and 13-30, whatever ``bands`` contains. Band edges
        are inclusive on both sides, so a frequency bin that falls exactly on
        a shared edge contributes to both neighbouring bands.
    """
    channels = CHANNELS if channels is None else channels
    bands = BANDS if bands is None else bands
    # Resolve the integrator locally so the function does not depend on a
    # module-level np.trapz alias being installed first.
    integrate = getattr(np, 'trapezoid', None) or np.trapz

    def band_power(freqs, psd, fmin, fmax):
        """Integrate ``psd`` over the closed interval [fmin, fmax]."""
        mask = (freqs >= fmin) & (freqs <= fmax)
        return integrate(psd[mask], freqs[mask])

    features = {}
    for ch_name, ch_data in zip(channels, data):
        # Welch PSD; the segment length is capped by the signal length
        freqs, psd = signal.welch(ch_data, fs=sfreq, nperseg=min(2048, len(ch_data)))

        # Band powers
        total_power = integrate(psd, freqs)
        for band_name, (fmin, fmax) in bands.items():
            power = band_power(freqs, psd, fmin, fmax)
            features[f'{ch_name}_{band_name}_abs'] = power
            features[f'{ch_name}_{band_name}_rel'] = power / (total_power + 1e-10)

        # Spectral ratios (AD biomarkers); 1e-10 guards against division by zero
        delta_power = band_power(freqs, psd, 0.5, 4)
        theta_power = band_power(freqs, psd, 4, 8)
        alpha_power = band_power(freqs, psd, 8, 13)
        beta_power = band_power(freqs, psd, 13, 30)
        features[f'{ch_name}_theta_alpha_ratio'] = theta_power / (alpha_power + 1e-10)
        features[f'{ch_name}_delta_alpha_ratio'] = delta_power / (alpha_power + 1e-10)
        features[f'{ch_name}_delta_theta_ratio'] = delta_power / (theta_power + 1e-10)
        features[f'{ch_name}_alpha_beta_ratio'] = alpha_power / (beta_power + 1e-10)

        # Peak alpha frequency (slows in AD); 0 when no bin lies in 8-13
        alpha_mask = (freqs >= 8) & (freqs <= 13)
        alpha_psd = psd[alpha_mask]
        if len(alpha_psd) > 0:
            features[f'{ch_name}_peak_alpha_freq'] = freqs[alpha_mask][np.argmax(alpha_psd)]
        else:
            features[f'{ch_name}_peak_alpha_freq'] = 0

        # Spectral entropy of the normalised PSD; zero bins are dropped
        # because log2(0) is undefined
        psd_norm = psd / (psd.sum() + 1e-10)
        psd_norm = psd_norm[psd_norm > 0]
        features[f'{ch_name}_spectral_entropy'] = -np.sum(psd_norm * np.log2(psd_norm))
    return features
def compute_temporal_features(data, channels=None):
    """Compute time-domain features per channel.

    Rows of ``data`` are paired positionally with ``channels``; pairing stops
    at whichever of the two is shorter.

    Args:
        data: 2-D array-like with one row of samples per channel.
        channels: Channel names in row order. Defaults to the module-level
            ``CHANNELS``.

    Returns:
        dict mapping ``'<channel>_<feature>'`` to a number. Per channel, in
        insertion order: mean, std, kurtosis, skewness, rms, the three Hjorth
        parameters and the zero crossing rate.
    """
    channels = CHANNELS if channels is None else channels
    features = {}
    for ch_name, ch_data in zip(channels, data):
        ch_data = np.asarray(ch_data)
        features[f'{ch_name}_mean'] = np.mean(ch_data)
        features[f'{ch_name}_std'] = np.std(ch_data)
        features[f'{ch_name}_kurtosis'] = kurtosis(ch_data)
        features[f'{ch_name}_skewness'] = skew(ch_data)
        features[f'{ch_name}_rms'] = np.sqrt(np.mean(ch_data ** 2))

        # Hjorth parameters (activity, mobility, complexity), built from the
        # variances of the signal and of its first and second differences.
        # 1e-10 guards every division against a zero denominator.
        diff1 = np.diff(ch_data)
        diff2 = np.diff(diff1)
        activity = np.var(ch_data)
        diff1_var = np.var(diff1)
        mobility = np.sqrt(diff1_var / (activity + 1e-10))
        complexity = np.sqrt(np.var(diff2) / (diff1_var + 1e-10)) / (mobility + 1e-10)
        features[f'{ch_name}_hjorth_activity'] = activity
        features[f'{ch_name}_hjorth_mobility'] = mobility
        features[f'{ch_name}_hjorth_complexity'] = complexity

        # Zero crossing rate: number of sign changes between consecutive
        # samples, divided by the number of samples
        zero_crossings = np.sum(np.diff(np.sign(ch_data)) != 0) / len(ch_data)
        features[f'{ch_name}_zero_crossing_rate'] = zero_crossings
    return features
def compute_connectivity_features(data, sfreq, channels=None):
    """Compute pairwise inter-channel coherence in the alpha and theta bands.

    For every unordered pair among the first ``min(rows, len(channels))``
    channels, the magnitude-squared coherence from ``scipy.signal.coherence``
    is averaged over the bins in 8-13 (alpha) and in 4-8 (theta).

    Args:
        data: 2-D array-like with one row of samples per channel.
        sfreq: Sampling frequency, forwarded to ``scipy.signal.coherence``.
        channels: Channel names in row order. Defaults to the module-level
            ``CHANNELS``.

    Returns:
        dict with keys ``'coh_alpha_<chA>_<chB>'`` and
        ``'coh_theta_<chA>_<chB>'``. A key is omitted when no frequency bin
        falls inside its band.
    """
    channels = CHANNELS if channels is None else channels
    data = np.asarray(data)
    n_channels = min(data.shape[0], len(channels))
    # Alpha is listed first so that keys are inserted in the same order as before.
    coherence_bands = (('alpha', 8, 13), ('theta', 4, 8))

    features = {}
    for i in range(n_channels):
        for j in range(i + 1, n_channels):
            freqs, coh = signal.coherence(
                data[i], data[j], fs=sfreq, nperseg=min(1024, len(data[i]))
            )
            for band_name, fmin, fmax in coherence_bands:
                mask = (freqs >= fmin) & (freqs <= fmax)
                if mask.any():
                    key = f'coh_{band_name}_{channels[i]}_{channels[j]}'
                    features[key] = np.mean(coh[mask])
    return features
def compute_regional_features(data, sfreq, channels=None, bands=None):
    """Compute region-averaged spectral features.

    The signals of each region's channels (frontal, temporal, parietal,
    central, occipital) are averaged sample-wise before a Welch PSD is taken.
    Channels that are not named in ``channels``, or whose position has no row
    in ``data``, are left out; a region with no usable channel is skipped.

    Args:
        data: 2-D array-like with one row of samples per channel.
        sfreq: Sampling frequency, forwarded to ``scipy.signal.welch``.
        channels: Channel names in row order. Defaults to the module-level
            ``CHANNELS``.
        bands: Mapping ``band_name -> (fmin, fmax)`` for the relative band
            powers. Defaults to the module-level ``BANDS``.

    Returns:
        dict with ``'region_<region>_<band>_rel'`` and
        ``'region_<region>_theta_alpha'`` per region, plus
        ``'frontal_parietal_alpha_asymmetry'`` when at least one of F3/F4 and
        at least one of P3/P4 are available.
    """
    channels = CHANNELS if channels is None else channels
    bands = BANDS if bands is None else bands
    # Resolve the integrator locally so the function does not depend on a
    # module-level np.trapz alias being installed first.
    integrate = getattr(np, 'trapezoid', None) or np.trapz
    data = np.asarray(data)

    regions = {
        'frontal': ['Fp1', 'Fp2', 'F3', 'F4', 'F7', 'F8', 'Fz'],
        'temporal': ['T3', 'T4', 'T5', 'T6'],
        'parietal': ['P3', 'P4', 'Pz'],
        'central': ['C3', 'C4', 'Cz'],
        'occipital': ['O1', 'O2'],
    }
    # Only map names that actually have a row, so a recording with fewer rows
    # than channel names is skipped instead of raising IndexError.
    ch_to_idx = {ch: i for i, ch in enumerate(channels) if i < data.shape[0]}

    def averaged_psd(names):
        """Welch PSD of the mean signal over ``names``; None if none usable."""
        indices = [ch_to_idx[ch] for ch in names if ch in ch_to_idx]
        if not indices:
            return None
        mean_signal = data[indices].mean(axis=0)
        return signal.welch(mean_signal, fs=sfreq, nperseg=min(2048, len(mean_signal)))

    def band_power(freqs, psd, fmin, fmax):
        """Integrate ``psd`` over the closed interval [fmin, fmax]."""
        mask = (freqs >= fmin) & (freqs <= fmax)
        return integrate(psd[mask], freqs[mask])

    features = {}
    for region_name, region_channels in regions.items():
        spectrum = averaged_psd(region_channels)
        if spectrum is None:
            continue
        freqs, psd = spectrum
        total_power = integrate(psd, freqs) + 1e-10
        for band_name, (fmin, fmax) in bands.items():
            features[f'region_{region_name}_{band_name}_rel'] = (
                band_power(freqs, psd, fmin, fmax) / total_power
            )
        # Theta/alpha ratio per region (fixed 4-8 over 8-13)
        features[f'region_{region_name}_theta_alpha'] = (
            band_power(freqs, psd, 4, 8) / (band_power(freqs, psd, 8, 13) + 1e-10)
        )

    # Inter-regional asymmetry (frontal vs parietal alpha — disrupted in AD)
    frontal = averaged_psd(['F3', 'F4'])
    parietal = averaged_psd(['P3', 'P4'])
    if frontal is not None and parietal is not None:
        # Both PSDs come from signals of equal length, so they share one
        # frequency axis; no extra welch call is needed to recover it.
        freqs, f_psd = frontal
        _, p_psd = parietal
        f_alpha = band_power(freqs, f_psd, 8, 13)
        p_alpha = band_power(freqs, p_psd, 8, 13)
        features['frontal_parietal_alpha_asymmetry'] = (
            (f_alpha - p_alpha) / (f_alpha + p_alpha + 1e-10)
        )
    return features
# ── Extract features for all subjects ──
# One feature dict, label and subject id is collected per successfully
# processed participant; ids that could not be processed go to `failed`.
all_features = []
all_labels = []
all_subjects = []
failed = []
for _, row in participants.iterrows():
    sub_id = row['participant_id']
    group = row['Group']
    label = label_map[group]
    eeg_file = os.path.join(BASE, sub_id, 'eeg', f'{sub_id}_task-eyesclosed_eeg.set')
    if not os.path.exists(eeg_file):
        # Report the skip instead of dropping the subject silently.
        print(f" {sub_id} SKIPPED: file not found: {eeg_file}")
        failed.append(sub_id)
        continue
    try:
        raw = mne.io.read_raw_eeglab(eeg_file, preload=True, verbose=False)
        # The feature functions pair data rows with CHANNELS by position, so
        # enforce that order whenever the recording carries those names.
        # Otherwise fall back to the recording's own row order.
        if set(CHANNELS).issubset(raw.ch_names):
            raw.reorder_channels(CHANNELS)
        else:
            print(f" {sub_id} WARNING: channel names differ from CHANNELS; using file order")
        # Bandpass filter 0.5-45 Hz
        raw.filter(0.5, 45, verbose=False)
        data = raw.get_data()
        sfreq = raw.info['sfreq']
        # Extract all feature groups
        feats = {}
        feats.update(compute_psd_features(data, sfreq))
        feats.update(compute_temporal_features(data))
        feats.update(compute_connectivity_features(data, sfreq))
        feats.update(compute_regional_features(data, sfreq))
        # Add demographic features
        # NOTE(review): MMSE is a cognitive screening score rather than an EEG
        # measurement. Confirm it is meant to be a model input, since it can
        # dominate an "EEG-based" classifier.
        feats['age'] = row['Age']
        feats['gender'] = 1 if row['Gender'] == 'M' else 0
        feats['mmse'] = row['MMSE']
        all_features.append(feats)
        all_labels.append(label)
        all_subjects.append(sub_id)
        print(f" {sub_id} [{label_names[label]}] — {len(feats)} features extracted")
    except Exception as e:
        # Best-effort batch run: report the subject and carry on with the rest.
        print(f" {sub_id} FAILED: {e}")
        failed.append(sub_id)
print(f"\nExtracted: {len(all_features)} subjects")
print(f"Failed: {len(failed)} subjects")
if not all_features:
    # Without this guard the training step fails later with an obscure
    # empty-matrix error.
    raise SystemExit(f"No EEG features were extracted; check the dataset under {BASE}")
# ── Convert to DataFrame ──
X = pd.DataFrame(all_features)
y = np.array(all_labels)
# Features missing for a subject become NaN when the dicts are stacked;
# they are replaced by 0.
X = X.fillna(0)
print(f"Feature matrix: {X.shape}")
print(f"Labels: AD={sum(y==0)}, Control={sum(y==1)}, FTD={sum(y==2)}")
# Save features
X.to_csv(os.path.join(OUTPUT_DIR, 'eeg_features.csv'), index=False)
np.save(os.path.join(OUTPUT_DIR, 'eeg_labels.npy'), y)
# ══════════════════════════════════════════════════════════════
# STEP 3: Train classifiers
# ══════════════════════════════════════════════════════════════
print(f"\n{'=' * 60}")
print("STEP 3: Training classifiers")
print("=" * 60)
from sklearn.base import clone
from sklearn.model_selection import StratifiedKFold, cross_val_predict
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier, ExtraTreesClassifier
from sklearn.svm import SVC
from sklearn.metrics import classification_report, confusion_matrix, roc_auc_score
from sklearn.feature_selection import SelectKBest, f_classif
# ── Feature selection: top 100 features ──
# This selector (and the scaler below) are fitted on ALL subjects. They are
# kept for the deployment artifacts and for the feature report only; the
# cross-validated scores further down refit their own copies inside each fold.
n_selected = min(100, X.shape[1])
selector = SelectKBest(f_classif, k=n_selected)
X_selected = selector.fit_transform(X, y)
# Get selected feature names
selected_mask = selector.get_support()
selected_features = X.columns[selected_mask].tolist()
print(f"Selected {len(selected_features)} features from {X.shape[1]}")
# Show top 20 features
scores = selector.scores_[selected_mask]
top_idx = np.argsort(scores)[::-1][:20]
print("\nTop 20 discriminative features:")
for i, idx in enumerate(top_idx):
    print(f" {i+1:2d}. {selected_features[idx]:45s} F={scores[idx]:.1f}")
# ── Scale ──
scaler = StandardScaler()
X_scaled = scaler.fit_transform(X_selected)
# Unselected, unscaled matrix that the cross-validation pipelines start from.
X_raw = X.to_numpy()


def make_cv_pipeline(estimator):
    """Return an unfitted select -> scale -> classify pipeline.

    Selection and scaling are refitted on the training part of every fold, so
    the labels and statistics of held-out subjects never reach the model.
    """
    return make_pipeline(
        SelectKBest(f_classif, k=n_selected),
        StandardScaler(),
        clone(estimator),
    )


# ── 3-class classification (AD vs Control vs FTD) ──
print(f"\n--- 3-class: AD vs Control vs FTD ---")
cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
models = {
    'GradientBoosting': GradientBoostingClassifier(
        n_estimators=200, max_depth=4, learning_rate=0.05, random_state=42,
    ),
    'RandomForest': RandomForestClassifier(
        n_estimators=300, max_depth=6, random_state=42,
    ),
    'ExtraTrees': ExtraTreesClassifier(
        n_estimators=300, max_depth=6, random_state=42,
    ),
    'SVM': SVC(kernel='rbf', C=10, gamma='scale', probability=True, random_state=42),
}
best_model_name = None
best_accuracy = 0
for name, model in models.items():
    y_pred = cross_val_predict(make_cv_pipeline(model), X_raw, y, cv=cv)
    acc = np.mean(y_pred == y)
    print(f"\n{name}: Accuracy = {acc:.1%}")
    print(classification_report(y, y_pred, target_names=['AD', 'Control', 'FTD']))
    if acc > best_accuracy:
        best_accuracy = acc
        best_model_name = name
print(f"\nBest 3-class model: {best_model_name} ({best_accuracy:.1%})")
# ── Binary classification: AD vs Control (drop FTD) ──
print(f"\n--- Binary: AD vs Control ---")
binary_mask = y != 2 # drop FTD
# X_binary keeps the all-data selection/scaling; it is what the final binary
# model is fitted on. Cross-validation uses the raw rows instead.
# NOTE(review): the in-fold selector here is fitted on the two-class labels,
# whereas the saved selector was fitted on three-class labels, so this is an
# estimate for a closely related procedure rather than the exact saved one.
X_binary = X_scaled[binary_mask]
X_binary_raw = X_raw[binary_mask]
y_binary = y[binary_mask]
# Remap: AD=0, Control=1 → AD=1, Control=0 for AUC
y_binary_pos = 1 - y_binary # AD=1, Control=0
cv_binary = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
for name, model in models.items():
    binary_pipeline = make_cv_pipeline(model)
    y_pred = cross_val_predict(binary_pipeline, X_binary_raw, y_binary, cv=cv_binary)
    y_prob = cross_val_predict(
        binary_pipeline, X_binary_raw, y_binary, cv=cv_binary, method='predict_proba',
    )
    acc = np.mean(y_pred == y_binary)
    # AUC: probability of being AD (class 0)
    auc = roc_auc_score(y_binary_pos, y_prob[:, 0])
    print(f"\n{name}: Accuracy = {acc:.1%}, AUC = {auc:.3f}")
    print(classification_report(y_binary, y_pred, target_names=['AD', 'Control']))
# ══════════════════════════════════════════════════════════════
# STEP 4: Train final model and save
# ══════════════════════════════════════════════════════════════
print(f"\n{'=' * 60}")
print("STEP 4: Training final models and saving")
print("=" * 60)
from sklearn.base import clone
# Train on all data for deployment (use best model type).
# Cloning the winning entry of `models` keeps the saved estimator and the
# stored `best_cv_accuracy` consistent; previously GradientBoosting was always
# fitted, whichever model had won.
final_3class = clone(models[best_model_name])
final_3class.fit(X_scaled, y)
final_binary = clone(models[best_model_name])
final_binary.fit(X_binary, y_binary)
# Save everything
artifacts = {
    'model_3class': final_3class,
    'model_binary': final_binary,
    'scaler': scaler,
    'selector': selector,
    'feature_names': list(X.columns),
    'selected_features': selected_features,
    'label_names_3class': {0: 'AD', 1: 'Control', 2: 'FTD'},
    'label_names_binary': {0: 'AD', 1: 'Control'},
    'channels': CHANNELS,
    'bands': BANDS,
    'best_cv_accuracy': best_accuracy,
    'best_model_name': best_model_name,
}
model_path = os.path.join(OUTPUT_DIR, 'eeg_ad_classifier.pkl')
with open(model_path, 'wb') as f:
    pickle.dump(artifacts, f)
print(f"\nSaved: {model_path} ({os.path.getsize(model_path)/1e6:.1f} MB)")
# ── Feature importance ──
# Only tree-based estimators expose feature_importances_ (SVC does not), so
# the ranking is skipped when the winning model has none.
if hasattr(final_3class, 'feature_importances_'):
    importances = final_3class.feature_importances_
    top_features = np.argsort(importances)[::-1][:15]
    print(f"\nTop 15 features ({best_model_name} importance):")
    for i, idx in enumerate(top_features):
        print(f" {i+1:2d}. {selected_features[idx]:45s} importance={importances[idx]:.4f}")
else:
    print(f"\n{best_model_name} does not expose feature importances; skipping ranking")
# ══════════════════════════════════════════════════════════════
# SUMMARY
# ══════════════════════════════════════════════════════════════
print(f"\n{'=' * 60}")
print("EEG AD CLASSIFIER — TRAINING COMPLETE")
print("=" * 60)
print(f" Dataset: OpenNeuro ds004504")
print(f" Subjects: {len(all_features)} ({sum(y==0)} AD, {sum(y==1)} Control, {sum(y==2)} FTD)")
print(f" Features: {X.shape[1]} total → {len(selected_features)} selected")
print(f" Best 3-class CV accuracy: {best_accuracy:.1%}")
print(f" Model saved: {model_path}")
print(f" Author: Satyawan Singh — Infonova Solutions")