Spaces:
Sleeping
Sleeping
| """ECG preprocessing for PTB-XL @ 500 Hz (CoRe-ECG / MAE-style pipeline).""" | |
| from __future__ import annotations | |
| import numpy as np | |
| from scipy import signal | |
# PTB-XL 500 Hz records: 10 s × 500 Hz → 5000 samples per lead.
DEFAULT_FS = 500
# Patch length in **samples** matches the paper (75). At 500 Hz one patch spans
# 75 / 500 = 0.15 s, whereas the paper's 75 samples @ 250 Hz span 0.30 s — we
# trade patch duration in seconds for finer time steps.
PATCH_SIZE = 75
# 5000 is not divisible by 75; use the largest multiple of 75 that fits
# (66 × 75 = 4950 samples, ~9.9 s).
NUM_PATCHES = 66
SIGNAL_LENGTH = NUM_PATCHES * PATCH_SIZE
# Paper: 0.65–40 Hz band.
BANDPASS_LOW_HZ = 0.65
BANDPASS_HIGH_HZ = 40.0
BANDPASS_ORDER = 4


def butter_bandpass(
    low_hz: float,
    high_hz: float,
    fs: float,
    order: int = BANDPASS_ORDER,
    output: str = "ba",
):
    """
    Design a digital Butterworth band-pass filter.

    Parameters
    ----------
    low_hz, high_hz : band edges in Hz
    fs : sampling rate in Hz
    order : Butterworth order passed to ``scipy.signal.butter`` (a band-pass of
        order N has 2N poles)
    output : ``"ba"`` (default, numerator/denominator tuple — kept for backward
        compatibility) or ``"sos"`` (second-order sections, numerically robust;
        use this for low cutoffs relative to ``fs``)

    Returns
    -------
    ``(b, a)`` when ``output="ba"``, an ``(order, 6)`` sos array when ``output="sos"``.

    Raises
    ------
    ValueError
        If the normalized band is empty after clamping.
    """
    nyq = 0.5 * fs
    # Clamp into the open interval (0, 1) that scipy requires for normalized edges.
    lo = max(low_hz / nyq, 1e-6)
    hi = min(high_hz / nyq, 0.999)
    if lo >= hi:
        raise ValueError(f"Invalid band [{low_hz}, {high_hz}] Hz for fs={fs}")
    return signal.butter(order, [lo, hi], btype="band", output=output)


def bandpass_filter(
    ecg: np.ndarray,
    fs: float = DEFAULT_FS,
    *,
    low_hz: float = BANDPASS_LOW_HZ,
    high_hz: float = BANDPASS_HIGH_HZ,
    order: int = BANDPASS_ORDER,
) -> np.ndarray:
    """
    Zero-phase Butterworth band-pass applied independently to each lead.

    Parameters
    ----------
    ecg : ndarray, shape (n_leads, n_samples) or (n_samples, n_leads)
        If the first axis is longer than the second, the array is assumed to be
        (n_samples, n_leads) and is transposed.
    fs : sampling rate in Hz
    low_hz, high_hz, order : band edges (Hz) and Butterworth order; the defaults
        are the module-level BANDPASS_* constants.

    Returns
    -------
    float32 ndarray of shape (n_leads, n_samples).

    Raises
    ------
    ValueError
        If ``ecg`` is not 2-D, the band is invalid, or the signal is too short
        for the forward-backward filter's edge padding.
    """
    if ecg.ndim != 2:
        raise ValueError("ecg must be 2-D")
    # Normalize to (n_leads, n_samples).
    if ecg.shape[0] > ecg.shape[1]:
        ecg = ecg.T
    # Second-order sections: with a 0.65 Hz edge at 500 Hz the poles sit very
    # close to z=1, where the single high-order (b, a) polynomial loses precision.
    sos = butter_bandpass(low_hz, high_hz, fs, order=order, output="sos")
    # Filter every lead in one vectorized call along the time axis.
    out = signal.sosfiltfilt(sos, np.asarray(ecg, dtype=np.float64), axis=1)
    return out.astype(np.float32)
def zscore_per_lead(ecg: np.ndarray, eps: float = 1e-6) -> np.ndarray:
    """
    Standardize each lead to zero mean and unit (population) std over time.

    Parameters
    ----------
    ecg : ndarray, shape (C, T)
    eps : added to the std so flat leads map to zeros instead of dividing by 0

    Raises
    ------
    ValueError
        If ``ecg`` is not 2-D.
    """
    if ecg.ndim != 2:
        raise ValueError("expected (C, T)")
    centered = ecg - np.mean(ecg, axis=1, keepdims=True)
    scale = np.std(ecg, axis=1, keepdims=True) + eps
    return centered / scale
def random_temporal_crop(ecg: np.ndarray, length: int, rng: np.random.Generator) -> np.ndarray:
    """
    Take a random ``length``-sample window from a (C, T) array.

    When T >= length, the window start is drawn uniformly from [0, T - length]
    and a copy of that slice is returned. When T < length, the signal is instead
    padded with its edge values, with the split between left and right padding
    drawn uniformly from ``rng``.
    """
    _, n_time = ecg.shape
    shortfall = length - n_time
    if shortfall <= 0:
        offset = int(rng.integers(0, -shortfall + 1))
        return ecg[:, offset : offset + length].copy()
    before = int(rng.integers(0, shortfall + 1))
    return np.pad(ecg, ((0, 0), (before, shortfall - before)), mode="edge")
def center_temporal_crop(ecg: np.ndarray, length: int) -> np.ndarray:
    """
    Deterministically fit a (C, T) array to ``length`` samples.

    Longer inputs are center-cropped (a copy is returned; with an odd surplus
    the extra sample is dropped from the right). Shorter inputs are edge-padded,
    with any odd remainder of padding going to the right.
    """
    _, n_time = ecg.shape
    excess = n_time - length
    if excess >= 0:
        begin = excess // 2
        return ecg[:, begin : begin + length].copy()
    deficit = -excess
    before = deficit // 2
    return np.pad(ecg, ((0, 0), (before, deficit - before)), mode="edge")
def preprocess_ecg(
    ecg_leads_first: np.ndarray,
    *,
    fs: float = DEFAULT_FS,
    training: bool = True,
    signal_length: int = SIGNAL_LENGTH,
    rng: np.random.Generator | None = None,
) -> np.ndarray:
    """
    Run the full pipeline: band-pass → crop/pad to ``signal_length`` → per-lead z-score.

    Parameters
    ----------
    ecg_leads_first : ndarray (12, T_raw), millivolts or arbitrary units from WFDB
        (``bandpass_filter`` also transposes a (T_raw, 12) array).
    fs : sampling rate in Hz, used for the band-pass design only
    training : random crop when True, center crop otherwise
    signal_length : output length in samples
    rng : generator for the random crop; a fresh unseeded one is created if None

    Returns
    -------
    float32 ndarray of shape (n_leads, signal_length).
    """
    generator = np.random.default_rng() if rng is None else rng
    filtered = bandpass_filter(ecg_leads_first, fs=fs)
    cropped = (
        random_temporal_crop(filtered, signal_length, generator)
        if training
        else center_temporal_crop(filtered, signal_length)
    )
    # Normalize after cropping so statistics describe exactly the window the model sees.
    return zscore_per_lead(cropped).astype(np.float32)
def ptbxl_wfdb_to_leads_first(p_signal: np.ndarray) -> np.ndarray:
    """
    Convert a WFDB ``p_signal`` array from (T, C) to a C-contiguous (C, T) float32 array.

    Lead order is unchanged (I … V6). The result is always a newly allocated
    array, so it never aliases ``p_signal``.

    Raises
    ------
    ValueError
        If ``p_signal`` is not 2-D.
    """
    if p_signal.ndim != 2:
        raise ValueError("p_signal must be 2-D")
    # One allocation: astype casts and lays out the transposed view in C order.
    # (astype followed by ascontiguousarray copied twice, since the transposed
    # astype result is Fortran-ordered.)
    return p_signal.T.astype(np.float32, order="C")