File size: 3,930 Bytes
7a63dcf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
"""ECG preprocessing for PTB-XL @ 500 Hz (CoRe-ECG / MAE-style pipeline)."""

from __future__ import annotations

import numpy as np
from scipy import signal

# Default sampling rate in Hz. PTB-XL 500 Hz records: 10 s × 500 Hz → 5000 samples per lead.
DEFAULT_FS = 500
# Patch length in **samples**, matching the paper (75). At 500 Hz one patch spans 75/500 = 0.15 s;
# the paper's 75 samples @ 250 Hz span 0.30 s, i.e. twice as long in **seconds** — we trade that
# for finer time steps.
PATCH_SIZE = 75
# 5000 is not divisible by 75; use the largest multiple of 75 that is ≤ 5000 (66 × 75 = 4950, ~9.9 s).
NUM_PATCHES = 66
# Default target length in samples per lead (4950); used as preprocess_ecg's `signal_length` default.
SIGNAL_LENGTH = NUM_PATCHES * PATCH_SIZE
# Paper: 0.65–40 Hz passband.
BANDPASS_LOW_HZ = 0.65
BANDPASS_HIGH_HZ = 40.0
# Default `order` argument of butter_bandpass (forwarded to scipy.signal.butter).
BANDPASS_ORDER = 4


def butter_bandpass(low_hz: float, high_hz: float, fs: float, order: int = BANDPASS_ORDER):
    """
    Design a digital Butterworth band-pass filter.

    Parameters
    ----------
    low_hz, high_hz : passband edges in Hz
    fs : sampling rate in Hz; must be positive
    order : filter order forwarded to ``scipy.signal.butter``

    Returns
    -------
    (b, a) : numerator and denominator coefficients as returned by
        ``scipy.signal.butter(order, [lo, hi], btype="band")``

    Raises
    ------
    ValueError
        If ``fs`` is not positive, or if the band is empty once the edges
        have been normalized and clamped.
    """
    # Reject fs <= 0 (and NaN) up front: a plain fs == 0 would otherwise
    # surface as ZeroDivisionError, unlike every other bad-input path here.
    if not fs > 0:
        raise ValueError(f"fs must be positive, got {fs}")
    nyq = 0.5 * fs
    # Normalize to the Nyquist frequency and clamp strictly inside (0, 1),
    # the open interval scipy accepts for digital critical frequencies.
    lo = max(low_hz / nyq, 1e-6)
    hi = min(high_hz / nyq, 0.999)
    if lo >= hi:
        raise ValueError(f"Invalid band [{low_hz}, {high_hz}] Hz for fs={fs}")
    return signal.butter(order, [lo, hi], btype="band")


def bandpass_filter(ecg: np.ndarray, fs: float = DEFAULT_FS) -> np.ndarray:
    """
    Run the module's Butterworth band-pass forward and backward (zero phase) over every lead.

    Parameters
    ----------
    ecg : ndarray, 2-D, shape (n_leads, n_samples) or (n_samples, n_leads);
        when there are more rows than columns the array is transposed first
    fs : sampling rate in Hz

    Returns
    -------
    ndarray, float32, one row per lead (i.e. in the possibly transposed layout)

    Raises
    ------
    ValueError
        If ``ecg`` is not 2-D.
    """
    if ecg.ndim != 2:
        raise ValueError("ecg must be 2-D")
    n_rows, n_cols = ecg.shape
    # The longer axis is taken to be time, so leads end up on axis 0.
    leads_first = ecg.T if n_rows > n_cols else ecg
    numer, denom = butter_bandpass(BANDPASS_LOW_HZ, BANDPASS_HIGH_HZ, fs)
    # Filter in float64; every row is overwritten below, so no zero-fill is needed.
    filtered = np.empty_like(leads_first, dtype=np.float64)
    for idx, lead in enumerate(leads_first):
        filtered[idx] = signal.filtfilt(numer, denom, lead.astype(np.float64))
    return filtered.astype(np.float32)


def zscore_per_lead(ecg: np.ndarray, eps: float = 1e-6) -> np.ndarray:
    """
    Standardize each row of a (C, T) array along the time axis.

    Each lead has its own mean subtracted and is divided by its own standard
    deviation plus ``eps``; ``eps`` keeps constant leads from dividing by zero.

    Raises
    ------
    ValueError
        If ``ecg`` is not 2-D.
    """
    if ecg.ndim != 2:
        raise ValueError("expected (C, T)")
    centered = ecg - ecg.mean(axis=1, keepdims=True)
    scale = ecg.std(axis=1, keepdims=True) + eps
    return centered / scale


def random_temporal_crop(ecg: np.ndarray, length: int, rng: np.random.Generator) -> np.ndarray:
    """
    Return a (C, length) window of a (C, T) array at a random position.

    When T >= length a random start offset is drawn from ``rng`` and a copy of
    that window is returned. When T < length the signal is edge-padded instead,
    with the split of the padding between left and right drawn from ``rng``.
    """
    _, n_samples = ecg.shape
    shortfall = length - n_samples
    if shortfall > 0:
        # Too short: repeat the edge samples, choosing at random how much goes in front.
        pad_before = int(rng.integers(0, shortfall + 1))
        pad_after = shortfall - pad_before
        return np.pad(ecg, ((0, 0), (pad_before, pad_after)), mode="edge")
    offset = int(rng.integers(0, n_samples - length + 1))
    return ecg[:, offset : offset + length].copy()


def center_temporal_crop(ecg: np.ndarray, length: int) -> np.ndarray:
    """
    Return a (C, length) version of a (C, T) array, centered in time.

    When T >= length the middle window is copied out. When T < length the
    signal is edge-padded on both sides, with the extra sample going to the
    right when the padding is odd.
    """
    _, n_samples = ecg.shape
    surplus = n_samples - length
    if surplus < 0:
        missing = -surplus
        pad_before = missing // 2
        return np.pad(ecg, ((0, 0), (pad_before, missing - pad_before)), mode="edge")
    first = surplus // 2
    return ecg[:, first : first + length].copy()


def preprocess_ecg(
    ecg_leads_first: np.ndarray,
    *,
    fs: float = DEFAULT_FS,
    training: bool = True,
    signal_length: int = SIGNAL_LENGTH,
    rng: np.random.Generator | None = None,
) -> np.ndarray:
    """
    Run the whole pipeline: bandpass → temporal crop/pad → per-lead z-score.

    Parameters
    ----------
    ecg_leads_first : ndarray (12, T_raw) millivolts or arbitrary units from WFDB
    fs : sampling rate in Hz, forwarded to the bandpass filter
    training : if True, crop at a random position; otherwise center crop
    signal_length : number of samples per lead in the output
    rng : generator used for the random crop; a fresh ``np.random.default_rng()``
        is created when omitted

    Returns
    -------
    ndarray, float32
    """
    generator = np.random.default_rng() if rng is None else rng
    filtered = bandpass_filter(ecg_leads_first, fs=fs)
    cropped = (
        random_temporal_crop(filtered, signal_length, generator)
        if training
        else center_temporal_crop(filtered, signal_length)
    )
    normalized = zscore_per_lead(cropped)
    return normalized.astype(np.float32)


def ptbxl_wfdb_to_leads_first(p_signal: np.ndarray) -> np.ndarray:
    """
    Convert a WFDB ``p_signal`` array from (T, C) to a C-contiguous float32 (C, T) array.

    Only the axes are swapped; the order of the leads (I … V6) is left as is.

    Raises
    ------
    ValueError
        If ``p_signal`` is not 2-D.
    """
    if p_signal.ndim != 2:
        raise ValueError("p_signal must be 2-D")
    leads_first = p_signal.T
    as_float32 = leads_first.astype(np.float32)
    return np.ascontiguousarray(as_float32)