---
library_name: pytorch
tags:
- ecg
- arrhythmia
- rhythm-classification
- ltaf
- physionet
- 1d-resnet
license: mit
datasets:
- physionet/ltafdb
---
# LTAF ECG Rhythm Classifier: RhythmResNet1D + TTA
A from-scratch 1D-ResNet trained on PhysioNet's
[Long-Term Atrial Fibrillation (LTAF)](https://physionet.org/content/ltafdb/)
database for **6-class rhythm classification** on two-lead 128 Hz ECG.
| Metric | Single-window | **+ 7-view TTA (recommended)** |
|---|---:|---:|
| Test accuracy | 0.636 | **0.684** |
| Test balanced accuracy | 0.740 | **0.778** |
| **Test macro F1** | **0.614** | **0.656** |
Compared with a frozen Chronos-2 + MLP baseline on the same 6-class subset
(test macro F1 = 0.299), this is **+36 pp / 2.2× the F1**.
Per-class F1 (TTA-7): NSR 0.76, AFIB 0.62, SBR 0.82, AB 0.77, SVTA 0.15, B 0.82.
## Classes
| Code | Expansion |
|------|-----------|
| NSR | Normal sinus rhythm |
| AFIB | Atrial fibrillation |
| SBR | Sinus bradycardia (<60 bpm, sinus origin) |
| AB | Atrial bigeminy (every other beat is an APC) |
| SVTA | Supraventricular tachyarrhythmia (≥3 consecutive SV ectopics at >100 bpm) |
| B | Ventricular bigeminy (every other beat is a PVC) |
`VT`, `T`, and `IVR` are excluded: their LTAF test supports (31, 26, 1) are too small for stable F1 estimation.
## Quickstart
```bash
pip install torch huggingface_hub numpy
```
```python
import numpy as np
import torch
from huggingface_hub import hf_hub_download
from model import RhythmResNet1D, RHYTHM_CLASS_NAMES
# Download the checkpoint from HF (the import above needs model.py from the same repo on the path)
ckpt = hf_hub_download("rmxjck/ltaf-ecg-rhythm-classifier", "best_classifier.pt")
device = "cuda" if torch.cuda.is_available() else "cpu"  # fall back to CPU when no GPU is present
model = RhythmResNet1D.load(ckpt, device=device)
model.eval()
# Input: (B, 2, 1280), i.e. 10 s @ 128 Hz, 2 leads, per-channel z-scored.
x = torch.randn(1, 2, 1280, device=device)  # replace with real ECG
with torch.no_grad():
    logits = model(x)
pred_idx = logits.argmax(-1).item()  # .item() works here because B == 1
print(model.class_names[pred_idx])
```
For best results, use the **7-view TTA** wrapper in `inference.py`,
which averages softmax outputs across 7 random window-start offsets. This
adds ~4 pp macro F1 at the cost of 7× inference compute.
```bash
python inference.py
```
## Architecture
`RhythmResNet1D(num_classes=6, n_channels=2, base_channels=64,
blocks_per_stage=2)`:
- **Stem:** Conv1d(2, 64, k=15, stride=2) → BN → ReLU → MaxPool(2).
- **4 ResNet stages × 2 basic blocks** (Conv1d k=7, BN, ReLU, Dropout, +skip).
Channels: 64 → 128 → 256 → 512. Time downsamples 2× at the start of each
stage past the first.
- **Head:** AdaptiveAvgPool1d → Linear(512 → 128) → ReLU → Dropout(0.2)
→ Linear(128 → 6).
- **Total parameters:** 8,794,246.
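To make the downsampling concrete, the sketch below traces the tensor shape of one 1280-sample window through the layers listed above. It is not taken from the model code: the padding values (7 for the k=15 stem, 3 for the k=7 stage convolutions) and the average-pool output size of 1 are assumptions chosen so that each stride-2 layer halves the time axis exactly, and `conv_out_len` / `trace_shapes` are hypothetical helper names.

```python
def conv_out_len(length: int, kernel: int, stride: int, padding: int) -> int:
    """Output length of a 1D convolution or pooling layer (dilation 1)."""
    return (length + 2 * padding - kernel) // stride + 1


def trace_shapes(n_samples: int = 1280, base_channels: int = 64):
    """Return (layer name, channels, time steps) after each part of the network."""
    shapes = [("input", 2, n_samples)]
    # Stem: Conv1d(2, 64, k=15, stride=2); padding 7 is an assumption.
    t = conv_out_len(n_samples, kernel=15, stride=2, padding=7)
    shapes.append(("stem conv", base_channels, t))
    # MaxPool(2): kernel 2, stride 2, no padding.
    t = conv_out_len(t, kernel=2, stride=2, padding=0)
    shapes.append(("stem pool", base_channels, t))
    channels = base_channels
    for stage in range(4):
        if stage > 0:
            # Stages after the first double the channels and halve the time axis
            # (k=7, stride 2; padding 3 is an assumption).
            channels *= 2
            t = conv_out_len(t, kernel=7, stride=2, padding=3)
        shapes.append((f"stage {stage + 1}", channels, t))
    # AdaptiveAvgPool1d collapses the time axis (output size 1 assumed).
    shapes.append(("avg pool", channels, 1))
    return shapes


for name, channels, steps in trace_shapes():
    print(f"{name:10s} channels={channels:4d} time={steps}")
```

Under these assumptions the last stage sees 40 time steps of 512 channels, so the head's first linear layer receives a 512-dimensional vector per window.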
## Input format
- `(B, 2, 1280)` float32
- 2-lead ECG at **128 Hz** (LTAF leads `ECG1`, `ECG2`)
- 10 s window
- Per-channel z-scored: `(x - x.mean(axis=-1, keepdims=True)) / x.std(axis=-1, keepdims=True)`
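A minimal NumPy sketch of the input format above: cut one 10 s window from a longer two-lead recording and z-score each lead. `make_window` is a hypothetical helper, not part of the repository, and the small `eps` added to the standard deviation is an assumption to avoid dividing by zero on a flat lead.

```python
import numpy as np

FS = 128          # sampling rate in Hz
WINDOW = 10 * FS  # 1280 samples per 10 s window


def make_window(signal: np.ndarray, start: int = 0, eps: float = 1e-8) -> np.ndarray:
    """Cut a 10 s window from a (2, T) recording and z-score each lead."""
    if signal.ndim != 2 or signal.shape[0] != 2:
        raise ValueError("expected a (2, T) array")
    if start < 0 or start + WINDOW > signal.shape[1]:
        raise ValueError("window does not fit inside the recording")
    win = signal[:, start:start + WINDOW].astype(np.float32)
    # keepdims=True keeps shape (2, 1) so the statistics broadcast over time
    mean = win.mean(axis=-1, keepdims=True)
    std = win.std(axis=-1, keepdims=True)
    return (win - mean) / (std + eps)  # eps guards against a flat lead (assumption)


rng = np.random.default_rng(0)
recording = rng.normal(size=(2, 30 * FS))        # stand-in for a real 30 s ECG
batch = make_window(recording, start=256)[None]  # add batch axis: (1, 2, 1280)
print(batch.shape, batch.dtype)
```

The resulting array can be wrapped with `torch.from_numpy(batch)` before being passed to the model.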
## Test-time augmentation (TTA)
Pass a longer signal slice (≥1280 samples) to `predict_tta()` and it
samples 7 random 10 s windows, averages the softmax outputs, then
argmaxes. Why it helps: training uses random window-start sampling
within each rhythm bout, so the model learns to be invariant to that
shift. At eval time, taking multiple shifts and averaging cancels the
position-specific noise. **+4.2 pp test macro F1, no retraining.**
```python
# (2, 30*128) signal, 30 s long
cls, prob, full_probs = predict_tta(model, long_signal, n_views=7, device="cuda")
```
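The averaging step described above can be sketched in a few lines of NumPy. This is an illustration of the idea, not the code in `inference.py`: `tta_sketch` and `predict_logits` are hypothetical names, the fixed seed is only there for reproducibility, and z-scoring each view separately is an assumption based on the input format.

```python
import numpy as np

WINDOW = 1280  # 10 s at 128 Hz


def softmax(logits: np.ndarray) -> np.ndarray:
    shifted = logits - logits.max(axis=-1, keepdims=True)  # numerical stability
    exp = np.exp(shifted)
    return exp / exp.sum(axis=-1, keepdims=True)


def tta_sketch(predict_logits, signal, n_views=7, seed=0):
    """Average class probabilities over n_views random 10 s windows.

    predict_logits: hypothetical callable, (N, 2, 1280) batch -> (N, n_classes) logits.
    signal: (2, T) array with T >= 1280.
    """
    n_samples = signal.shape[-1]
    if n_samples < WINDOW:
        raise ValueError("signal shorter than one window")
    rng = np.random.default_rng(seed)
    # upper bound is exclusive, so the last valid start is n_samples - WINDOW
    starts = rng.integers(0, n_samples - WINDOW + 1, size=n_views)
    views = np.stack([signal[:, s:s + WINDOW] for s in starts])
    # z-score every lead of every view (assumed to match the input format)
    views = (views - views.mean(-1, keepdims=True)) / (views.std(-1, keepdims=True) + 1e-8)
    probs = softmax(predict_logits(views)).mean(axis=0)  # average over views
    return int(probs.argmax()), float(probs.max()), probs


def dummy_model(batch):
    """Stand-in for the real network: returns the same logits for every view."""
    return np.tile(np.array([0.1, 2.0, 0.3, 0.0, -1.0, 0.5]), (len(batch), 1))


long_signal = np.random.default_rng(1).normal(size=(2, 30 * 128))
cls, prob, full_probs = tta_sketch(dummy_model, long_signal)
print(cls, round(prob, 3), full_probs.shape)
```

Averaging probabilities rather than logits mirrors the "averages the softmax outputs, then argmaxes" description; with a real model, `predict_logits` would wrap a `torch.no_grad()` forward pass.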
## Training recipe
```bash
.venv/bin/python scripts/train_ecg_rhythm_scratch.py \
--arch resnet1d --window-sizes 10 \
--epochs 30 --batch-size 64 --lr 5e-4 \
--base-channels 64 \
--use-val-as-train \
--classes NSR AFIB SBR AB SVTA B \
--output-dir results/ecg_classifier/sweep/c6_resnet1d_w10_e30_wide
```
- Dataset: LTAF train+val combined (75 records). 8 records held out for
early stopping. Test (9 records, 3,716 windows) untouched.
- Loss: weighted cross-entropy with sqrt-dampened inverse-frequency
class weights (cap 10), label smoothing 0.1.
- Cosine LR schedule from 5e-4 → 0 over 30 epochs. AdamW (wd 1e-4).
- Best checkpoint by held-out macro F1.
- Training time on a single H100 80GB: **~6 minutes**.
Source repo: `scripts/train_ecg_rhythm_scratch.py` and
`src/models/ts_llm/ecg_rhythm_scratch.py` in
[rmxjck/TSLM-Arena](https://github.com/rmxjck/TSLM-Arena).
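For readers who want a feel for the loss weighting and learning-rate schedule in the recipe above, here is a small standard-library sketch. It is not copied from the training script: normalising the weights by the largest class is an assumption (the card only states "sqrt-dampened inverse-frequency, cap 10"), and the counts used are the test-set supports, since the training-set class counts are not listed here.

```python
import math


def class_weights(counts, cap=10.0):
    """Square-root-dampened inverse-frequency weights, clipped at `cap`.

    Normalising by the largest class is an assumption of this sketch.
    """
    largest = max(counts.values())
    return {name: min(math.sqrt(largest / n), cap) for name, n in counts.items()}


def cosine_lr(epoch, total_epochs=30, base_lr=5e-4):
    """Cosine decay from base_lr at epoch 0 to 0 at total_epochs."""
    return 0.5 * base_lr * (1.0 + math.cos(math.pi * epoch / total_epochs))


# Illustrative counts only: test-set supports, not the training distribution.
counts = {"NSR": 1824, "AFIB": 1179, "SBR": 305, "AB": 250, "SVTA": 60, "B": 98}
for name, weight in class_weights(counts).items():
    print(f"{name:5s} {weight:.2f}")
for epoch in (0, 15, 30):
    print(epoch, f"{cosine_lr(epoch):.2e}")
```

The square root keeps rare classes such as SVTA from dominating the loss the way plain inverse-frequency weights would, and the cap bounds the weight of any class that is rarer still.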
## Test set details
LTAF held-out split (deterministic seed 42, record-level): 9 records
(`100, 104, 105, 11, 200, 32, 48, 49, 68`), 3,716 windows.
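Confusion matrix (rows = true, cols = pred). The counts below reproduce the
single-window metrics (accuracy 0.636, balanced accuracy 0.740, macro F1 0.614)
rather than the TTA ones: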
| | NSR | AFIB | SBR | AB | SVTA | B |
|-------|----:|-----:|----:|----:|-----:|----:|
| NSR | 1109 | 286 | 95 | 114 | 185 | 35 |
| AFIB | 189 | 628 | 25 | 29 | 294 | 14 |
| SBR | 26 | 0 | 279 | 0 | 0 | 0 |
| AB | 9 | 13 | 0 | 225 | 3 | 0 |
| SVTA | 9 | 14 | 0 | 3 | 34 | 0 |
| B | 4 | 0 | 0 | 1 | 3 | 90 |
Per-class supports: NSR 1824, AFIB 1179, SBR 305, AB 250, SVTA 60, B 98.
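The headline metrics can be recomputed directly from a confusion matrix like the one above. The sketch below uses only the standard library; `summarize` is a hypothetical helper, and the counts are copied from the table as printed.

```python
LABELS = ["NSR", "AFIB", "SBR", "AB", "SVTA", "B"]
CONFUSION = [  # rows = true class, cols = predicted class
    [1109, 286, 95, 114, 185, 35],
    [189, 628, 25, 29, 294, 14],
    [26, 0, 279, 0, 0, 0],
    [9, 13, 0, 225, 3, 0],
    [9, 14, 0, 3, 34, 0],
    [4, 0, 0, 1, 3, 90],
]


def summarize(matrix):
    """Return accuracy, balanced accuracy, macro F1 and per-class F1."""
    n = len(matrix)
    total = sum(map(sum, matrix))
    row_sums = [sum(row) for row in matrix]                        # supports
    col_sums = [sum(row[j] for row in matrix) for j in range(n)]   # predicted counts
    diag = [matrix[i][i] for i in range(n)]                        # true positives
    accuracy = sum(diag) / total
    balanced = sum(d / r for d, r in zip(diag, row_sums)) / n      # mean recall
    # F1 = 2TP / (2TP + FP + FN) = 2TP / (support + predicted count)
    f1 = [2 * d / (r + c) if r + c else 0.0
          for d, r, c in zip(diag, row_sums, col_sums)]
    return accuracy, balanced, sum(f1) / n, f1


accuracy, balanced_accuracy, macro_f1, per_class_f1 = summarize(CONFUSION)
print(f"accuracy={accuracy:.3f} balanced={balanced_accuracy:.3f} macro_f1={macro_f1:.3f}")
for label, score in zip(LABELS, per_class_f1):
    print(f"{label:5s} F1={score:.2f}")
```

Balanced accuracy exceeds plain accuracy here because the small classes (SBR, AB, B) have high recall while the two large classes (NSR, AFIB) carry most of the errors.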
## What was tried and didn't help
This model was the best of 30+ experiments. What did *not* improve over
this baseline:
- HRV side-channel input (8-dim RR-derived features fused with CNN trunk):
hurts F1 by 3-8 pp because the CNN already extracts equivalent
information from raw QRS timing.
- Cross-corpus augmentation (MIT-BIH AFDB added to training): hurts
AFIB F1 by 14 pp because AFDB's clean AFIB blocks bias the model
toward over-calling AFIB on LTAF's paroxysmal transitions.
- Wider models (96-channel, 12 M params): overfit.
- Longer training (50 epochs): overfits.
- Multi-model soft-voting ensembles: members make correlated errors.
- Focal loss: matches CE within noise.
- Multi-scale training (5 / 10 / 30 s windows): underperforms 10 s alone.
- Bigger external models (torchecg ResNet-50 51.9 M, Stanford 27 M):
underperform a 2.2 M home-rolled ResNet1D at 12 epochs.
## Not for clinical use
Research artifact only. **Not FDA-cleared.** Not suitable for triage,
diagnosis, or any patient-facing application. The model is trained and evaluated
on the LTAF benchmark, which has known label noise from its original PhysioNet curation.
## Citation
```bibtex
@misc{petrutiu2008ltafdb,
title = {Abrupt Changes in Fibrillatory Wave Characteristics at the Termination of Paroxysmal Atrial Fibrillation in Humans},
author = {Petrutiu, Simona and Sahakian, Alan V. and Swiryn, Steven},
year = {2008},
howpublished = {PhysioNet},
url = {https://physionet.org/content/ltafdb/}
}
```