# File: ecg_reconstruction/preprocessor.py
# Repository page header (extraction residue): PhurinutR's picture
# Commit message: "followed CoRe-ECG idea"
# Commit: 7a63dcf
"""ECG preprocessing for PTB-XL @ 500 Hz (CoRe-ECG / MAE-style pipeline)."""
from __future__ import annotations
import numpy as np
from scipy import signal
# --- Sampling ---------------------------------------------------------------
# A 10 s PTB-XL record sampled at 500 Hz holds 10 × 500 = 5000 samples per lead.
DEFAULT_FS = 500

# --- Patching ---------------------------------------------------------------
# The patch length is given in **samples** and keeps the paper's value (75).
# At 500 Hz one patch therefore lasts 75/500 s, half the **duration** of the
# paper's 75 samples @ 250 Hz; the shorter patch buys finer time resolution.
PATCH_SIZE = 75
# 75 does not divide 5000, so keep the largest multiple of 75 that fits:
# 66 × 75 = 4950 samples (~9.9 s).
NUM_PATCHES = 66
SIGNAL_LENGTH = PATCH_SIZE * NUM_PATCHES

# --- Filtering --------------------------------------------------------------
# Pass band taken from the paper: 0.65–40 Hz.
BANDPASS_LOW_HZ = 0.65
BANDPASS_HIGH_HZ = 40.0
BANDPASS_ORDER = 4
def butter_bandpass(low_hz: float, high_hz: float, fs: float, order: int = BANDPASS_ORDER):
    """
    Design a Butterworth band-pass filter for the given band.

    Parameters
    ----------
    low_hz : lower band edge in Hz
    high_hz : upper band edge in Hz
    fs : sampling rate in Hz
    order : filter order handed to ``scipy.signal.butter``

    Returns
    -------
    The result of ``signal.butter(order, [lo, hi], btype="band")``, where
    ``lo`` and ``hi`` are the band edges divided by the Nyquist frequency.

    Raises
    ------
    ValueError
        If, after clamping, the lower edge is not strictly below the upper one.
    """
    nyquist = fs * 0.5
    # Keep the normalized edges away from 0 and 1: the lower edge is floored
    # at 1e-6 and the upper edge is capped at 0.999.
    band_lo = max(low_hz / nyquist, 1e-6)
    band_hi = min(high_hz / nyquist, 0.999)
    if band_hi <= band_lo:
        raise ValueError(f"Invalid band [{low_hz}, {high_hz}] Hz for fs={fs}")
    normalized_band = [band_lo, band_hi]
    return signal.butter(order, normalized_band, btype="band")
def bandpass_filter(ecg: np.ndarray, fs: float = DEFAULT_FS) -> np.ndarray:
    """
    Apply a zero-phase Butterworth band-pass to every lead.

    The band is fixed by ``BANDPASS_LOW_HZ`` / ``BANDPASS_HIGH_HZ``; the filter
    is run forward and backward with ``scipy.signal.filtfilt``.

    Parameters
    ----------
    ecg : ndarray, shape (n_leads, n_samples) or (n_samples, n_leads)
        If the first axis is longer than the second, the array is transposed
        so that the shorter axis comes first.
    fs : sampling rate in Hz

    Returns
    -------
    ndarray of dtype float32 with the shorter axis first.

    Raises
    ------
    ValueError
        If ``ecg`` is not 2-D.
    """
    if ecg.ndim != 2:
        raise ValueError("ecg must be 2-D")

    # Orientation heuristic: the longer axis is taken to be time.
    n_rows, n_cols = ecg.shape
    leads = ecg.T if n_rows > n_cols else ecg

    numer, denom = butter_bandpass(BANDPASS_LOW_HZ, BANDPASS_HIGH_HZ, fs)

    # Filter in float64, one lead at a time; cast down only at the very end.
    filtered = np.zeros_like(leads, dtype=np.float64)
    for idx, lead in enumerate(leads):
        filtered[idx] = signal.filtfilt(numer, denom, lead.astype(np.float64))
    return filtered.astype(np.float32)
def zscore_per_lead(ecg: np.ndarray, eps: float = 1e-6) -> np.ndarray:
    """
    Standardize each row of a (C, T) array along its time axis.

    Every lead has its own mean subtracted and is divided by its own standard
    deviation plus ``eps``; the added ``eps`` keeps the division finite for a
    constant lead.

    Raises
    ------
    ValueError
        If ``ecg`` is not 2-D.
    """
    if ecg.ndim != 2:
        raise ValueError("expected (C, T)")
    # keepdims=True keeps the statistics as (C, 1) so they broadcast over time.
    centered = ecg - ecg.mean(axis=1, keepdims=True)
    scale = ecg.std(axis=1, keepdims=True) + eps
    return centered / scale
def random_temporal_crop(ecg: np.ndarray, length: int, rng: np.random.Generator) -> np.ndarray:
    """
    Return a (C, length) window of a (C, T) array at a random position.

    When ``T >= length`` a start index is drawn uniformly from the valid range
    and a copy of that slice is returned. When ``T < length`` the missing
    samples are added by edge-value padding, with the split between left and
    right padding drawn at random. Exactly one draw is taken from ``rng``.
    """
    _, n_samples = ecg.shape
    shortfall = length - n_samples

    if shortfall > 0:
        # Too short: decide at random how much of the padding goes in front.
        before = int(rng.integers(0, shortfall + 1))
        after = shortfall - before
        return np.pad(ecg, ((0, 0), (before, after)), mode="edge")

    # Long enough: the upper bound is exclusive, hence the +1.
    offset = int(rng.integers(0, n_samples - length + 1))
    window = ecg[:, offset : offset + length]
    return window.copy()
def center_temporal_crop(ecg: np.ndarray, length: int) -> np.ndarray:
    """
    Return the central (C, length) window of a (C, T) array.

    When ``T >= length`` a copy of the middle slice is returned (for an odd
    surplus the extra sample is dropped on the right). When ``T < length`` the
    array is padded with edge values, as evenly as possible; for an odd
    deficit the extra sample goes on the right.
    """
    _, n_samples = ecg.shape
    shortfall = length - n_samples

    if shortfall > 0:
        before = shortfall // 2
        after = shortfall - before
        return np.pad(ecg, ((0, 0), (before, after)), mode="edge")

    offset = (n_samples - length) // 2
    window = ecg[:, offset : offset + length]
    return window.copy()
def preprocess_ecg(
    ecg_leads_first: np.ndarray,
    *,
    fs: float = DEFAULT_FS,
    training: bool = True,
    signal_length: int = SIGNAL_LENGTH,
    rng: np.random.Generator | None = None,
) -> np.ndarray:
    """
    Run the whole pipeline: band-pass → temporal crop → per-lead z-score.

    Parameters
    ----------
    ecg_leads_first : ndarray (12, T_raw) millivolts or arbitrary units from WFDB
    fs : sampling rate in Hz, forwarded to the band-pass filter
    training : if True, crop at a random position; else take the center crop
    signal_length : number of samples kept per lead after cropping/padding
    rng : generator used for the random crop; a fresh ``default_rng()`` is
        created when omitted

    Returns
    -------
    ndarray of dtype float32 holding the filtered, cropped and z-scored signal.
    """
    generator = np.random.default_rng() if rng is None else rng

    filtered = bandpass_filter(ecg_leads_first, fs=fs)
    cropped = (
        random_temporal_crop(filtered, signal_length, generator)
        if training
        else center_temporal_crop(filtered, signal_length)
    )
    # Normalize after cropping so the statistics describe the returned window.
    normalized = zscore_per_lead(cropped)
    return normalized.astype(np.float32)
def ptbxl_wfdb_to_leads_first(p_signal: np.ndarray) -> np.ndarray:
    """
    Turn a WFDB ``p_signal`` of shape (T, C) into a leads-first (C, T) array.

    The result is a new C-contiguous float32 array; the order of the leads is
    left as it was (I … V6).

    Raises
    ------
    ValueError
        If ``p_signal`` is not 2-D.
    """
    if p_signal.ndim != 2:
        raise ValueError("p_signal must be 2-D")
    leads_first = p_signal.T
    # np.array copies by default, so the result never aliases the input.
    return np.array(leads_first, dtype=np.float32, order="C")