rmxjck committed on
Commit 464d595 · verified · 1 Parent(s): 87417d8

Initial release

README.md ADDED
@@ -0,0 +1,189 @@
+ ---
+ library_name: pytorch
+ tags:
+ - ecg
+ - arrhythmia
+ - rhythm-classification
+ - ltaf
+ - physionet
+ - 1d-resnet
+ license: mit
+ datasets:
+ - physionet/ltafdb
+ ---
+
+ # LTAF ECG Rhythm Classifier: RhythmResNet1D + TTA
+
+ A from-scratch 1D-ResNet trained on PhysioNet's
+ [Long-Term Atrial Fibrillation (LTAF)](https://physionet.org/content/ltafdb/)
+ database for **6-class rhythm classification** on two-lead 128 Hz ECG.
+
+ | Metric | Single-window | **+ 7-view TTA (recommended)** |
+ |---|---:|---:|
+ | Test accuracy | 0.636 | **0.684** |
+ | Test balanced accuracy | 0.740 | **0.778** |
+ | **Test macro F1** | **0.614** | **0.656** |
+
+ Compared with a frozen Chronos-2 + MLP baseline on the same 6-class subset
+ (test macro F1 = 0.299), the TTA macro F1 is **+36 pp, or 2.2×, higher**.
+
+ Per-class F1 (TTA-7): NSR 0.76, AFIB 0.62, SBR 0.82, AB 0.77, SVTA 0.15, B 0.82.
+
+ ## Classes
+
+ | Code | Expansion |
+ |------|-----------|
+ | NSR | Normal sinus rhythm |
+ | AFIB | Atrial fibrillation |
+ | SBR | Sinus bradycardia (<60 bpm, sinus origin) |
+ | AB | Atrial bigeminy (every other beat is an APC) |
+ | SVTA | Supraventricular tachyarrhythmia (≥3 consecutive SV ectopics at >100 bpm) |
+ | B | Ventricular bigeminy (every other beat is a PVC) |
+
+ `VT`, `T`, and `IVR` are excluded because their LTAF test supports (31, 26, 1) are too small for stable F1 estimation.
+
+ ## Quickstart
+
+ ```bash
+ pip install torch huggingface_hub numpy
+ ```
+
+ ```python
+ import sys
+ from pathlib import Path
+
+ import torch
+ from huggingface_hub import hf_hub_download
+
+ REPO = "rmxjck/ltaf-ecg-rhythm-classifier"
+
+ # Download the checkpoint and model.py from HF, then make model.py importable.
+ ckpt = hf_hub_download(REPO, "best_classifier.pt")
+ sys.path.insert(0, str(Path(hf_hub_download(REPO, "model.py")).parent))
+ from model import RhythmResNet1D
+
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ model = RhythmResNet1D.load(ckpt, device=device)  # load() also calls .eval()
+
+ # Input: (B, 2, 1280), i.e. 10 s @ 128 Hz, 2 leads, per-channel z-scored.
+ x = torch.randn(1, 2, 1280, device=device)  # replace with real ECG
+ with torch.no_grad():
+     logits = model(x)
+ pred_idx = logits.argmax(-1).item()
+ print(model.class_names[pred_idx])
+ ```
+
+ For best results, use the **7-view TTA** wrapper in `inference.py`
+ (averages softmax across 7 random window-start offsets; adds ~4 pp F1
+ at the cost of 7× inference compute).
+
+ ```bash
+ python inference.py
+ ```
+
+ ## Architecture
+
+ `RhythmResNet1D(num_classes=6, n_channels=2, base_channels=64,
+ blocks_per_stage=2)`:
+
+ - **Stem:** Conv1d(2, 64, k=15, stride=2) → BN → ReLU → MaxPool(2).
+ - **4 ResNet stages × 2 basic blocks** (Conv1d k=7, BN, ReLU, Dropout, +skip).
+   Channels: 64 → 128 → 256 → 512. Time downsamples 2× at the start of each
+   stage past the first.
+ - **Head:** AdaptiveAvgPool1d → Linear(512 → 128) → ReLU → Dropout(0.2)
+   → Linear(128 → 6).
+ - **Total parameters:** 8,794,246.
+
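The layer list above is enough to re-derive both the parameter total and the time resolution the head sees. The following is a rough sketch in plain Python (no PyTorch needed); the helper names `conv_out_len` and `count_params_and_length` are made up for this illustration, and the layer layout is taken from `model.py` in this commit (bias-free convolutions, BatchNorm with weight and bias, a 1×1 projection on blocks that change stride or width).

```python
def conv_out_len(length, kernel, stride, padding):
    """Output length of a 1-D convolution (standard formula, dilation 1)."""
    return (length + 2 * padding - kernel) // stride + 1


def count_params_and_length(n_channels=2, base=64, blocks_per_stage=2,
                            stem_kernel=15, block_kernel=7, num_classes=6,
                            in_len=1280):
    """Return (learnable parameter count, time length before global pooling)."""
    # Stem: bias-free conv + BatchNorm (weight and bias), then MaxPool(2).
    params = n_channels * base * stem_kernel + 2 * base
    length = conv_out_len(in_len, stem_kernel, 2, stem_kernel // 2) // 2

    in_c = out_c = base
    for stage in range(4):
        for block in range(blocks_per_stage):
            stride = 2 if (block == 0 and stage > 0) else 1
            params += in_c * out_c * block_kernel + 2 * out_c    # conv1 + bn1
            params += out_c * out_c * block_kernel + 2 * out_c   # conv2 + bn2
            if stride != 1 or in_c != out_c:
                params += in_c * out_c + 2 * out_c               # 1x1 projection + bn
            length = conv_out_len(length, block_kernel, stride, block_kernel // 2)
            in_c = out_c
        out_c = min(out_c * 2, 512)

    # Head: Linear(in_c -> 128) and Linear(128 -> num_classes), both with bias.
    params += in_c * 128 + 128 + 128 * num_classes + num_classes
    return params, length


total, final_len = count_params_and_length()
print(f"{total:,} parameters, final feature map 512 x {final_len}")
# prints "8,794,246 parameters, final feature map 512 x 40"
```

Under these assumptions a 10 s window is reduced from 1280 samples to 40 time steps (a 32× reduction) before the adaptive average pool, and the count matches the total stated above.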
+ ## Input format
+
+ - `(B, 2, 1280)` float32
+ - 2-lead ECG at **128 Hz** (LTAF leads `ECG1`, `ECG2`)
+ - 10 s window
+ - Per-channel z-scored: `(x - x.mean(axis=-1, keepdims=True)) / x.std(axis=-1, keepdims=True)`
+   (`zscore()` in `inference.py` also adds `1e-6` to the std to avoid division by zero)
+
+ ## Test-time augmentation (TTA)
+
+ Pass a longer signal slice (≥1280 samples) to `predict_tta()` and it
+ samples 7 random 10 s windows, averages the softmax outputs, then
+ argmaxes. Why it helps: training uses random window-start sampling
+ within each rhythm bout, so the model learns to be invariant to that
+ shift. At eval time, taking multiple shifts and averaging them reduces the
+ position-specific noise. **+4.2 pp test macro F1, no retraining.**
+
+ ```python
+ from inference import predict_tta  # inference.py from this repo
+
+ # long_signal: (2, 30*128) array, 30 s long
+ cls, prob, full_probs = predict_tta(model, long_signal, n_views=7, device="cuda")
+ ```
+
+ ## Training recipe
+
+ ```bash
+ .venv/bin/python scripts/train_ecg_rhythm_scratch.py \
+     --arch resnet1d --window-sizes 10 \
+     --epochs 30 --batch-size 64 --lr 5e-4 \
+     --base-channels 64 \
+     --use-val-as-train \
+     --classes NSR AFIB SBR AB SVTA B \
+     --output-dir results/ecg_classifier/sweep/c6_resnet1d_w10_e30_wide
+ ```
+
+ - Dataset: LTAF train+val combined (75 records), 8 of them held out for
+   early stopping. Test (9 records, 3,716 windows) untouched.
+ - Loss: weighted cross-entropy with sqrt-dampened inverse-frequency
+   class weights (cap 10), label smoothing 0.1.
+ - Cosine LR schedule from 5e-4 → 0 over 30 epochs. AdamW (wd 1e-4).
+ - Best checkpoint by held-out macro F1.
+ - Training time on a single H100 80GB: **~6 minutes**.
+
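The recipe names "sqrt-dampened inverse-frequency class weights (cap 10)" without giving the formula. One plausible reading, shown here as a sketch only, is `w_c = min(sqrt(max_count / count_c), cap)`. That exact form, the helper name `sqrt_inverse_frequency_weights`, and the use of the README's test supports as stand-in counts are assumptions for illustration; the training script in the source repo is the authority on what was actually used.

```python
import math


def sqrt_inverse_frequency_weights(counts, cap=10.0):
    """Assumed form: w_c = min(sqrt(max_count / count_c), cap).

    The most frequent class gets weight 1.0; rarer classes get larger
    weights, growing with the square root of the imbalance ratio.
    """
    max_count = max(counts.values())
    return {name: min(math.sqrt(max_count / n), cap) for name, n in counts.items()}


# Illustration only: these are the README's *test* supports, not training counts.
supports = {"NSR": 1824, "AFIB": 1179, "SBR": 305, "AB": 250, "SVTA": 60, "B": 98}
weights = sqrt_inverse_frequency_weights(supports)
for name, w in weights.items():
    print(f"{name:5s} {w:.2f}")
```

In PyTorch, weights like these would typically be passed as a tensor to `torch.nn.CrossEntropyLoss(weight=..., label_smoothing=0.1)`. The square root keeps the rarest class (SVTA here) from dominating the loss the way a plain inverse-frequency weight would.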
+ Source repo: `scripts/train_ecg_rhythm_scratch.py` and
+ `src/models/ts_llm/ecg_rhythm_scratch.py` in
+ [rmxjck/TSLM-Arena](https://github.com/rmxjck/TSLM-Arena).
+
+ ## Test set details
+
+ LTAF held-out split (deterministic seed 42, record-level): 9 records
+ (`100, 104, 105, 11, 200, 32, 48, 49, 68`), 3,716 windows.
+
+ Confusion matrix (rows = true, cols = pred). These counts reproduce the
+ single-window metrics above (accuracy 0.636, balanced accuracy 0.740,
+ macro F1 0.614), not the TTA ones:
+
+ |       |  NSR | AFIB | SBR |  AB | SVTA |  B |
+ |-------|-----:|-----:|----:|----:|-----:|---:|
+ | NSR   | 1109 |  286 |  95 | 114 |  185 | 35 |
+ | AFIB  |  189 |  628 |  25 |  29 |  294 | 14 |
+ | SBR   |   26 |    0 | 279 |   0 |    0 |  0 |
+ | AB    |    9 |   13 |   0 | 225 |    3 |  0 |
+ | SVTA  |    9 |   14 |   0 |   3 |   34 |  0 |
+ | B     |    4 |    0 |   0 |   1 |    3 | 90 |
+
+ Per-class supports: NSR 1824, AFIB 1179, SBR 305, AB 250, SVTA 60, B 98.
+
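The summary metrics can be recomputed directly from the confusion matrix above. The sketch below does this in plain Python; the `summarize` helper is hypothetical and not part of this repo. Rows are true classes and columns are predictions, as in the table.

```python
CLASSES = ["NSR", "AFIB", "SBR", "AB", "SVTA", "B"]
CM = [
    [1109, 286, 95, 114, 185, 35],
    [189, 628, 25, 29, 294, 14],
    [26, 0, 279, 0, 0, 0],
    [9, 13, 0, 225, 3, 0],
    [9, 14, 0, 3, 34, 0],
    [4, 0, 0, 1, 3, 90],
]


def summarize(cm):
    """Return (accuracy, balanced accuracy, macro F1, per-class F1 list)."""
    n = len(cm)
    total = sum(sum(row) for row in cm)
    accuracy = sum(cm[i][i] for i in range(n)) / total
    recalls, f1s = [], []
    for i in range(n):
        tp = cm[i][i]
        support = sum(cm[i])                        # row sum: true count
        predicted = sum(cm[r][i] for r in range(n)) # column sum: predicted count
        recall = tp / support if support else 0.0
        precision = tp / predicted if predicted else 0.0
        denom = precision + recall
        recalls.append(recall)
        f1s.append(2 * precision * recall / denom if denom else 0.0)
    return accuracy, sum(recalls) / n, sum(f1s) / n, f1s


acc, bal_acc, macro_f1, per_class_f1 = summarize(CM)
print(f"accuracy={acc:.3f} balanced_accuracy={bal_acc:.3f} macro_f1={macro_f1:.3f}")
# prints "accuracy=0.636 balanced_accuracy=0.740 macro_f1=0.614"
for name, f1 in zip(CLASSES, per_class_f1):
    print(f"{name:5s} F1={f1:.2f}")
```

The row sums equal the per-class supports listed above. SVTA's low F1 comes from precision rather than recall: 34 of 60 true SVTA windows are found, but 519 windows are predicted as SVTA, most of them true NSR or AFIB.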
+ ## What was tried and didn't help
+
+ This model was the best of 30+ experiments. What did *not* improve over
+ this baseline:
+
+ - HRV side-channel input (8-dim RR-derived features fused with CNN trunk):
+   hurts F1 by 3-8 pp because the CNN already extracts equivalent
+   information from raw QRS timing.
+ - Cross-corpus augmentation (MIT-BIH AFDB added to training): hurts
+   AFIB F1 by 14 pp because AFDB's clean AFIB blocks bias the model
+   toward over-calling AFIB on LTAF's paroxysmal transitions.
+ - Wider models (96-channel, 12 M params): overfit.
+ - Longer training (50 epochs): overfits.
+ - Multi-model soft-voting ensembles: members make correlated errors.
+ - Focal loss: matches CE within noise.
+ - Multi-scale training (5 / 10 / 30 s windows): underperforms 10 s alone.
+ - Bigger external models (torchecg ResNet-50 51.9 M, Stanford 27 M):
+   underperform a 2.2 M home-rolled ResNet1D at 12 epochs.
+
+ ## Not for clinical use
+
+ Research artifact only. **Not FDA-cleared.** Not suitable for triage,
+ diagnosis, or any patient-facing application. It uses the LTAF benchmark,
+ which has known label noise from its original PhysioNet curation.
+
+ ## Citation
+
+ ```bibtex
+ @misc{petrutiu2008ltafdb,
+   title = {Abrupt Changes in Fibrillatory Wave Characteristics at the Termination of Paroxysmal Atrial Fibrillation in Humans},
+   author = {Petrutiu, Simona and Sahakian, Alan V. and Swiryn, Steven},
+   year = {2008},
+   howpublished = {PhysioNet},
+   url = {https://physionet.org/content/ltafdb/}
+ }
+ ```
__pycache__/model.cpython-312.pyc ADDED
Binary file (6.34 kB).
best_classifier.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:e52d7453256a051cbcc5516a1462d9869a5baf8e2e5caeb2d2a0e5b69fa3e961
+ size 35256343
inference.py ADDED
@@ -0,0 +1,119 @@
+ #!/usr/bin/env python3
+ # SPDX-License-Identifier: MIT
+ """Inference example for the LTAF ECG rhythm classifier.
+
+ Two modes:
+   - Single-window: pass one (2, 1280) z-scored 10 s @ 128 Hz array.
+   - TTA-7 (recommended, +4 pp F1): pass a longer signal slice and the
+     function will pull 7 random 10 s windows from it and soft-vote.
+
+ Usage:
+     .venv/bin/python inference.py
+ """
+
+ from __future__ import annotations
+
+ from pathlib import Path
+ from typing import Tuple
+
+ import numpy as np
+ import torch
+ import torch.nn.functional as F
+ from huggingface_hub import hf_hub_download
+
+ from model import RHYTHM_CLASS_NAMES, RhythmResNet1D
+
+
+ WINDOW_SECONDS = 10
+ SOURCE_HZ = 128
+ WINDOW_SAMPLES = WINDOW_SECONDS * SOURCE_HZ  # 1280
+
+
+ def load_model(device: str = "cpu") -> RhythmResNet1D:
+     """Download the checkpoint from HF and load it."""
+     ckpt_path = hf_hub_download(
+         "rmxjck/ltaf-ecg-rhythm-classifier",
+         "best_classifier.pt",
+     )
+     return RhythmResNet1D.load(ckpt_path, device=device)
+
+
+ def zscore(window: np.ndarray) -> np.ndarray:
+     """Per-channel z-score a (C, L) array."""
+     mean = window.mean(axis=-1, keepdims=True)
+     std = window.std(axis=-1, keepdims=True)
+     return ((window - mean) / (std + 1e-6)).astype(np.float32, copy=False)
+
+
+ def predict_single(
+     model: RhythmResNet1D,
+     window: np.ndarray,
+     device: str = "cpu",
+ ) -> Tuple[str, float]:
+     """Predict on one (2, 1280) z-scored window. Returns (class_name, prob)."""
+     if window.shape != (2, WINDOW_SAMPLES):
+         raise ValueError(f"Expected (2, {WINDOW_SAMPLES}), got {window.shape}")
+     x = torch.from_numpy(window).float().unsqueeze(0).to(device)
+     with torch.no_grad():
+         probs = F.softmax(model(x), dim=-1)[0]
+     idx = int(probs.argmax().item())
+     return model.class_names[idx], float(probs[idx].item())
+
+
+ def predict_tta(
+     model: RhythmResNet1D,
+     long_signal: np.ndarray,
+     n_views: int = 7,
+     device: str = "cpu",
+     seed: int = 42,
+ ) -> Tuple[str, float, np.ndarray]:
+     """TTA-soft-voting prediction over a longer (2, L) signal.
+
+     Samples ``n_views`` random 10 s windows from ``long_signal`` (L >= 1280),
+     z-scores each independently, runs them through the model, and averages
+     the softmax probabilities.
+
+     Returns (class_name, prob, full_probs) where full_probs is shape (6,).
+     """
+     n_ch, n_samples = long_signal.shape
+     if n_ch != 2:
+         raise ValueError(f"Expected 2-channel signal, got {n_ch}")
+     if n_samples < WINDOW_SAMPLES:
+         raise ValueError(f"Need at least {WINDOW_SAMPLES} samples, got {n_samples}")
+     rng = np.random.default_rng(seed)
+     starts = rng.integers(0, n_samples - WINDOW_SAMPLES + 1, size=n_views)
+     accum = torch.zeros(model.num_classes, device=device)
+     for s in starts:
+         window = zscore(long_signal[:, s:s + WINDOW_SAMPLES])
+         x = torch.from_numpy(window).float().unsqueeze(0).to(device)
+         with torch.no_grad():
+             probs = F.softmax(model(x), dim=-1)[0]
+         accum += probs
+     probs_avg = accum / n_views
+     idx = int(probs_avg.argmax().item())
+     return model.class_names[idx], float(probs_avg[idx].item()), probs_avg.cpu().numpy()
+
+
+ def demo():
+     print("Loading model from HF...")
+     device = "cuda" if torch.cuda.is_available() else "cpu"
+     model = load_model(device)
+     print(f"Loaded {model.__class__.__name__} on {device}")
+     print(f"Classes: {model.class_names}")
+     print(f"Params: {sum(p.numel() for p in model.parameters()):,}")
+
+     # Synthetic example: random noise (will get garbage prediction).
+     print("\n--- single-window demo (random input) ---")
+     fake_window = zscore(np.random.randn(2, WINDOW_SAMPLES).astype(np.float32))
+     cls, prob = predict_single(model, fake_window, device=device)
+     print(f"prediction: {cls} ({prob:.1%})")
+
+     print("\n--- TTA-7 demo (random 30 s input) ---")
+     fake_long = np.random.randn(2, 30 * SOURCE_HZ).astype(np.float32)
+     cls, prob, full = predict_tta(model, fake_long, n_views=7, device=device)
+     print(f"prediction: {cls} ({prob:.1%})")
+     print(f"all class probs: {dict(zip(model.class_names, [round(p, 3) for p in full.tolist()]))}")
+
+
+ if __name__ == "__main__":
+     demo()
model.py ADDED
@@ -0,0 +1,116 @@
+ # SPDX-License-Identifier: MIT
+ """Self-contained RhythmResNet1D for LTAF rhythm classification.
+
+ Vendored from rmxjck/TSLM-Arena (src/models/ts_llm/ecg_rhythm_scratch.py)
+ so the model can be loaded with no external project imports.
+ """
+
+ from __future__ import annotations
+
+ from pathlib import Path
+ from typing import List
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+
+ RHYTHM_CLASS_NAMES = ["NSR", "AFIB", "SBR", "AB", "SVTA", "B"]
+
+
+ class _BasicBlock1D(nn.Module):
+     """Two-conv residual block with optional stride-2 downsample."""
+
+     def __init__(self, in_c: int, out_c: int, kernel: int = 7, stride: int = 1,
+                  dropout: float = 0.1):
+         super().__init__()
+         pad = kernel // 2
+         self.conv1 = nn.Conv1d(in_c, out_c, kernel_size=kernel, stride=stride,
+                                padding=pad, bias=False)
+         self.bn1 = nn.BatchNorm1d(out_c)
+         self.conv2 = nn.Conv1d(out_c, out_c, kernel_size=kernel, stride=1,
+                                padding=pad, bias=False)
+         self.bn2 = nn.BatchNorm1d(out_c)
+         self.drop = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
+         if stride != 1 or in_c != out_c:
+             self.proj = nn.Sequential(
+                 nn.Conv1d(in_c, out_c, kernel_size=1, stride=stride, bias=False),
+                 nn.BatchNorm1d(out_c),
+             )
+         else:
+             self.proj = nn.Identity()
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         identity = self.proj(x)
+         h = F.relu(self.bn1(self.conv1(x)), inplace=True)
+         h = self.drop(h)
+         h = self.bn2(self.conv2(h))
+         return F.relu(h + identity, inplace=True)
+
+
+ class RhythmResNet1D(nn.Module):
+     """1D ResNet: stem + 4 stages; each stage after the first halves time and doubles channels."""
+
+     def __init__(
+         self,
+         num_classes: int = 6,
+         class_names: List[str] = RHYTHM_CLASS_NAMES,
+         n_channels: int = 2,
+         base_channels: int = 64,
+         blocks_per_stage: int = 2,
+         stem_kernel: int = 15,
+         block_kernel: int = 7,
+         dropout: float = 0.2,
+     ):
+         super().__init__()
+         assert len(class_names) == num_classes
+         self.num_classes = num_classes
+         self.class_names = list(class_names)
+         self.n_channels = n_channels
+         self.base_channels = base_channels
+         self.blocks_per_stage = blocks_per_stage
+
+         self.stem = nn.Sequential(
+             nn.Conv1d(n_channels, base_channels, kernel_size=stem_kernel,
+                       stride=2, padding=stem_kernel // 2, bias=False),
+             nn.BatchNorm1d(base_channels),
+             nn.ReLU(inplace=True),
+             nn.MaxPool1d(2),
+         )
+         stages = []
+         in_c = base_channels
+         out_c = base_channels
+         for s in range(4):
+             for b in range(blocks_per_stage):
+                 stride = 2 if (b == 0 and s > 0) else 1
+                 stages.append(_BasicBlock1D(in_c, out_c, kernel=block_kernel,
+                                             stride=stride, dropout=dropout))
+                 in_c = out_c
+             out_c = min(out_c * 2, 512)
+         self.stages = nn.Sequential(*stages)
+         self.pool = nn.AdaptiveAvgPool1d(1)
+         self.head = nn.Sequential(
+             nn.Linear(in_c, 128),
+             nn.ReLU(inplace=True),
+             nn.Dropout(dropout),
+             nn.Linear(128, num_classes),
+         )
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         h = self.stem(x)
+         h = self.stages(h)
+         feat = self.pool(h).squeeze(-1)
+         return self.head(feat)
+
+     @classmethod
+     def load(cls, path: str | Path, device: str = "cpu") -> "RhythmResNet1D":
+         ckpt = torch.load(path, map_location=device, weights_only=False)
+         model = cls(
+             num_classes=ckpt["num_classes"], class_names=ckpt["class_names"],
+             n_channels=ckpt["n_channels"],
+             base_channels=ckpt.get("base_channels", 64),
+             blocks_per_stage=ckpt.get("blocks_per_stage", 2),
+         )
+         model.load_state_dict(ckpt["state_dict"])
+         model.to(device).eval()
+         return model
requirements.txt ADDED
@@ -0,0 +1,3 @@
+ torch>=2.0
+ huggingface_hub>=0.20
+ numpy>=1.24