"""ECG preprocessing for PTB-XL @ 500 Hz (CoRe-ECG / MAE-style pipeline)."""

from __future__ import annotations

import numpy as np
from scipy import signal

# PTB-XL 500 Hz records: 10 s x 500 Hz -> 5000 samples per lead.
DEFAULT_FS = 500

# Patch length in **samples** matches the paper (75). At 500 Hz a patch spans
# 75/500 s = 0.15 s; the paper's 75 samples @ 250 Hz span 0.3 s, i.e. twice as
# long in **seconds**. We trade that for finer time steps.
PATCH_SIZE = 75

# 5000 is not divisible by 75; use the largest length <= 5000 that is
# (66 x 75 = 4950 samples, ~9.9 s).
NUM_PATCHES = 66
SIGNAL_LENGTH = NUM_PATCHES * PATCH_SIZE

# Paper: 0.65-40 Hz band.
BANDPASS_LOW_HZ = 0.65
BANDPASS_HIGH_HZ = 40.0
BANDPASS_ORDER = 4


def butter_bandpass(
    low_hz: float,
    high_hz: float,
    fs: float,
    order: int = BANDPASS_ORDER,
    *,
    output: str = "ba",
):
    """
    Design a digital Butterworth bandpass filter.

    Parameters
    ----------
    low_hz, high_hz : band edges in Hz
    fs : sampling rate in Hz
    order : Butterworth order passed to ``scipy.signal.butter``
    output : coefficient form forwarded to ``scipy.signal.butter``:
        ``"ba"`` (default, numerator/denominator), ``"sos"`` (second-order
        sections) or ``"zpk"``.

    Returns
    -------
    Filter coefficients in the requested ``output`` form.

    Raises
    ------
    ValueError
        If the band is empty after clamping the normalized edges.
    """
    nyq = 0.5 * fs
    # Clamp the normalized edges into the open interval (0, 1) that scipy
    # requires for digital filters, rather than failing on a 0 Hz low edge or
    # a high edge at/above Nyquist.
    lo = max(low_hz / nyq, 1e-6)
    hi = min(high_hz / nyq, 0.999)
    if lo >= hi:
        raise ValueError(f"Invalid band [{low_hz}, {high_hz}] Hz for fs={fs}")
    return signal.butter(order, [lo, hi], btype="band", output=output)


def bandpass_filter(
    ecg: np.ndarray,
    fs: float = DEFAULT_FS,
    *,
    low_hz: float = BANDPASS_LOW_HZ,
    high_hz: float = BANDPASS_HIGH_HZ,
    order: int = BANDPASS_ORDER,
) -> np.ndarray:
    """
    Zero-phase bandpass (Butterworth) on each lead.

    Parameters
    ----------
    ecg : ndarray, shape (n_leads, n_samples) or (n_samples, n_leads)
    fs : sampling rate in Hz
    low_hz, high_hz, order : filter band (Hz) and Butterworth order; default
        to the paper's 0.65-40 Hz, order 4.

    Returns
    -------
    ndarray, float32, shape (n_leads, n_samples) -- always leads-first,
    whichever orientation was passed in.

    Raises
    ------
    ValueError
        If ``ecg`` is not 2-D, if the band is invalid, or (from scipy) if the
        signal is too short for the filter's edge padding.
    """
    if ecg.ndim != 2:
        raise ValueError("ecg must be 2-D")
    # Orientation heuristic: the longer axis is assumed to be time, so a
    # (n_samples, n_leads) input is transposed to (n_leads, n_samples).
    if ecg.shape[0] > ecg.shape[1]:
        ecg = ecg.T
    # Second-order sections instead of (b, a): at fs=500 the 0.65 Hz edge is
    # ~0.0026 of Nyquist, where the single high-degree transfer-function
    # polynomial is ill-conditioned. SOS keeps the same response, numerically
    # stable. Filtering in float64 along time, all leads at once.
    sos = butter_bandpass(low_hz, high_hz, fs, order, output="sos")
    x = np.asarray(ecg, dtype=np.float64)
    return signal.sosfiltfilt(sos, x, axis=-1).astype(np.float32)


def zscore_per_lead(ecg: np.ndarray, eps: float = 1e-6) -> np.ndarray:
    """
    Z-score along time for each lead: shape (C, T).

    ``eps`` is added to the standard deviation so a flat lead maps to zeros
    instead of dividing by zero.
    """
    if ecg.ndim != 2:
        raise ValueError("expected (C, T)")
    mean = ecg.mean(axis=1, keepdims=True)
    std = ecg.std(axis=1, keepdims=True)
    return (ecg - mean) / (std + eps)


def random_temporal_crop(
    ecg: np.ndarray, length: int, rng: np.random.Generator
) -> np.ndarray:
    """
    Crop (C, T) to (C, length) with a random start.

    If T < length, pad with edge values instead, splitting the padding
    between left and right at a random point.

    Raises
    ------
    ValueError
        If ``ecg`` is not 2-D, or if padding is needed but T == 0.
    """
    if ecg.ndim != 2:
        raise ValueError("expected (C, T)")
    t = ecg.shape[1]
    if t >= length:
        # rng.integers' upper bound is exclusive -> start in [0, t - length].
        start = int(rng.integers(0, t - length + 1))
        return ecg[:, start : start + length].copy()
    if t == 0:
        raise ValueError("cannot pad an empty signal")
    pad = length - t
    left = int(rng.integers(0, pad + 1))
    return np.pad(ecg, ((0, 0), (left, pad - left)), mode="edge")


def center_temporal_crop(ecg: np.ndarray, length: int) -> np.ndarray:
    """
    Center-crop (C, T) to (C, length), or edge-pad symmetrically if T < length.

    On odd padding the extra sample goes to the right.

    Raises
    ------
    ValueError
        If ``ecg`` is not 2-D, or if padding is needed but T == 0.
    """
    if ecg.ndim != 2:
        raise ValueError("expected (C, T)")
    t = ecg.shape[1]
    if t >= length:
        start = (t - length) // 2
        return ecg[:, start : start + length].copy()
    if t == 0:
        raise ValueError("cannot pad an empty signal")
    pad = length - t
    left = pad // 2
    return np.pad(ecg, ((0, 0), (left, pad - left)), mode="edge")


def preprocess_ecg(
    ecg_leads_first: np.ndarray,
    *,
    fs: float = DEFAULT_FS,
    training: bool = True,
    signal_length: int = SIGNAL_LENGTH,
    rng: np.random.Generator | None = None,
) -> np.ndarray:
    """
    Full preprocessing: bandpass -> (random or center) crop -> z-score per lead.

    Parameters
    ----------
    ecg_leads_first : ndarray (12, T_raw)
        Millivolts or arbitrary units from WFDB.
    fs : sampling rate in Hz
    training : if True, random crop; else deterministic center crop
    signal_length : output length in samples
    rng : generator for the random crop; a fresh unseeded one is created
        when None.

    Returns
    -------
    ndarray, float32, shape (n_leads, signal_length)
    """
    if rng is None:
        rng = np.random.default_rng()
    x = bandpass_filter(ecg_leads_first, fs=fs)
    if training:
        x = random_temporal_crop(x, signal_length, rng)
    else:
        x = center_temporal_crop(x, signal_length)
    x = zscore_per_lead(x)
    return x.astype(np.float32)


def ptbxl_wfdb_to_leads_first(p_signal: np.ndarray) -> np.ndarray:
    """
    Convert a WFDB ``p_signal`` (T, C) to a C-contiguous float32 (C, T) array.

    Lead order is unchanged (I ... V6).
    """
    if p_signal.ndim != 2:
        raise ValueError("p_signal must be 2-D")
    # A single conversion: astype() on the transposed view would keep F-order
    # and force ascontiguousarray to copy a second time.
    return np.ascontiguousarray(p_signal.T, dtype=np.float32)