---
library_name: pytorch
tags:
- ecg
- arrhythmia
- rhythm-classification
- ltaf
- physionet
- 1d-resnet
license: mit
datasets:
- physionet/ltafdb
---

# LTAF ECG Rhythm Classifier: RhythmResNet1D + TTA

A from-scratch 1D-ResNet trained on PhysioNet's
[Long-Term Atrial Fibrillation (LTAF)](https://physionet.org/content/ltafdb/)
database for **6-class rhythm classification** on two-lead 128 Hz ECG.

| Metric | Single-window | **+ 7-view TTA (recommended)** |
|---|---:|---:|
| Test accuracy | 0.636 | **0.684** |
| Test balanced accuracy | 0.740 | **0.778** |
| **Test macro F1** | **0.614** | **0.656** |

For comparison, a frozen Chronos-2 + MLP baseline on the same 6-class subset
reaches test macro F1 = 0.299, so the TTA model is **+36 pp / 2.2× the F1**.

Per-class F1 (TTA-7): NSR 0.76, AFIB 0.62, SBR 0.82, AB 0.77, SVTA 0.15, B 0.82.

## Classes

| Code | Expansion |
|------|-----------|
| NSR | Normal sinus rhythm |
| AFIB | Atrial fibrillation |
| SBR | Sinus bradycardia (<60 bpm, sinus origin) |
| AB | Atrial bigeminy (every other beat is an APC) |
| SVTA | Supraventricular tachyarrhythmia (≥3 consecutive supraventricular ectopic beats at >100 bpm) |
| B | Ventricular bigeminy (every other beat is a PVC) |

`VT`, `T`, and `IVR` are excluded: their LTAF test supports (31, 26, and 1, respectively) are too small for stable F1 estimation.

## Quickstart

```bash
pip install torch huggingface_hub numpy
```

```python
import numpy as np
import torch
from huggingface_hub import hf_hub_download
from model import RhythmResNet1D, RHYTHM_CLASS_NAMES

# Download the checkpoint from HF; model.py must also be on the import path.
ckpt = hf_hub_download("rmxjck/ltaf-ecg-rhythm-classifier", "best_classifier.pt")
device = "cuda" if torch.cuda.is_available() else "cpu"
model = RhythmResNet1D.load(ckpt, device=device)
model.eval()

# Input: (B, 2, 1280), i.e. 10 s @ 128 Hz, 2 leads, per-channel z-scored.
x = torch.randn(1, 2, 1280, device=device)  # replace with real ECG
with torch.no_grad():
    logits = model(x)
pred_idx = logits.argmax(-1).item()
print(model.class_names[pred_idx])
```

For best results, use the **7-view TTA** wrapper in `inference.py`. It
averages the softmax across 7 random window-start offsets, which adds
~4 pp macro F1 at the cost of 7× inference compute.

```bash
python inference.py
```

## Architecture

`RhythmResNet1D(num_classes=6, n_channels=2, base_channels=64, blocks_per_stage=2)`:

- **Stem:** Conv1d(2, 64, k=15, stride=2) → BN → ReLU → MaxPool(2).
- **4 ResNet stages × 2 basic blocks** (Conv1d k=7, BN, ReLU, Dropout, +skip).
  Channels: 64 → 128 → 256 → 512. Time downsamples 2× at the start of each
  stage after the first.
- **Head:** AdaptiveAvgPool1d → Linear(512 → 128) → ReLU → Dropout(0.2)
  → Linear(128 → 6).
- **Total parameters:** 8,794,246.

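
The parameter total can be cross-checked with plain arithmetic over the layer list above. The sketch below is not the repository's code. It assumes bias-free convolutions, affine BatchNorm, and a 1×1 conv + BN projection on the skip path wherever the channel width changes, none of which the card states explicitly. Under those assumptions the count comes out to exactly 8,794,246.

```python
# Parameter-count arithmetic for the architecture described above.
# Assumptions (not stated in the card): convolutions have no bias,
# BatchNorm is affine (scale + shift per channel), and the first block
# of stages 2-4 uses a 1x1 conv + BN projection on the skip path.

def conv_params(c_in, c_out, k):
    return c_in * c_out * k  # weights only, no bias


def bn_params(c):
    return 2 * c  # learnable scale and shift


def linear_params(n_in, n_out):
    return n_in * n_out + n_out  # weight matrix + bias


def basic_block_params(c_in, c_out, k=7):
    total = conv_params(c_in, c_out, k) + bn_params(c_out)
    total += conv_params(c_out, c_out, k) + bn_params(c_out)
    if c_in != c_out:  # projection shortcut when the width changes
        total += conv_params(c_in, c_out, 1) + bn_params(c_out)
    return total


def count_params(n_channels=2, base=64, blocks_per_stage=2, num_classes=6):
    total = conv_params(n_channels, base, 15) + bn_params(base)  # stem
    c_in = base
    for stage in range(4):
        c_out = base * 2 ** stage  # 64, 128, 256, 512 for base=64
        for _ in range(blocks_per_stage):
            total += basic_block_params(c_in, c_out)
            c_in = c_out
    # Head: Linear(c_in -> 128) then Linear(128 -> num_classes)
    total += linear_params(c_in, 128) + linear_params(128, num_classes)
    return total


print(count_params())         # 8794246
print(count_params(base=32))  # 2218502
```

Under the same assumptions, `base=32` gives about 2.2 M parameters, which may be the configuration behind the 2.2 M figure quoted under "What was tried and didn't help".
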
## Input format

- `(B, 2, 1280)` float32
- 2-lead ECG at **128 Hz** (LTAF leads `ECG1`, `ECG2`)
- 10 s window
- Per-channel z-scored: `(x - x.mean(axis=-1, keepdims=True)) / x.std(axis=-1, keepdims=True)`

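
For the subtraction and division to broadcast over the time axis, the per-channel statistics must keep their trailing axis. A minimal NumPy sketch of the normalisation follows. The function name and the small `eps` guard against flat-line leads are illustrative additions, not part of the released code.

```python
import numpy as np


def zscore_per_channel(x, eps=1e-8):
    """Standardise each lead of a (..., 2, 1280) window independently.

    `eps` is an assumption added here to avoid division by zero on a
    constant (flat-line) lead.
    """
    x = np.asarray(x, dtype=np.float32)
    mean = x.mean(axis=-1, keepdims=True)  # shape (..., 2, 1), broadcasts over time
    std = x.std(axis=-1, keepdims=True)
    return (x - mean) / (std + eps)


rng = np.random.default_rng(0)
window = rng.normal(loc=0.3, scale=2.0, size=(2, 1280))  # synthetic stand-in for real ECG
z = zscore_per_channel(window)
print(z.shape, z.dtype)  # (2, 1280) float32
```

The same function works unchanged on a batch of shape `(B, 2, 1280)`, because the statistics are always taken over the last axis.
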
## Test-time augmentation (TTA)

Pass a longer signal slice (≥1280 samples) to `predict_tta()`; it
samples 7 random 10 s windows, averages the softmax outputs, and takes
the argmax. Why it helps: training samples a random window start
within each rhythm bout, so the model learns to be invariant to that
shift. At evaluation time, averaging over several shifts cancels the
position-specific noise. **+4.2 pp test macro F1, no retraining.**

```python
# long_signal: (2, 30 * 128) array, i.e. a 30 s two-lead recording
cls, prob, full_probs = predict_tta(model, long_signal, n_views=7, device="cuda")
```

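
The averaging step can be sketched independently of the model. The snippet below is an illustration, not the code in `inference.py`: `tta_predict`, `logits_fn`, and `toy_logits` are hypothetical names, the network is replaced by a toy callable, and re-standardising each sampled view is an assumption made to match the input format described earlier.

```python
import numpy as np


def softmax(logits):
    z = logits - logits.max(axis=-1, keepdims=True)  # subtract max for stability
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)


def tta_predict(logits_fn, signal, n_views=7, window=1280, seed=0):
    """Average softmax over n_views random windows of a (2, T) signal.

    logits_fn is a hypothetical callable mapping a (n_views, 2, window)
    array to (n_views, n_classes) logits; it stands in for the model.
    """
    n_samples = signal.shape[-1]
    if n_samples < window:
        raise ValueError("signal is shorter than one window")
    rng = np.random.default_rng(seed)
    starts = rng.integers(0, n_samples - window + 1, size=n_views)
    views = np.stack([signal[:, s:s + window] for s in starts])
    # Assumption: each view is z-scored per channel before inference.
    mean = views.mean(axis=-1, keepdims=True)
    std = views.std(axis=-1, keepdims=True)
    views = (views - mean) / (std + 1e-8)
    probs = softmax(logits_fn(views)).mean(axis=0)  # average over the views
    cls = int(probs.argmax())
    return cls, float(probs[cls]), probs


def toy_logits(views, n_classes=6):
    """Toy stand-in for the network: fixed random readout of per-lead amplitude."""
    w = np.random.default_rng(1).normal(size=(2, n_classes))
    return np.abs(views).mean(axis=-1) @ w


long_signal = np.random.default_rng(2).normal(size=(2, 30 * 128))
cls, prob, full_probs = tta_predict(toy_logits, long_signal)
print(cls, round(prob, 3), full_probs.shape)
```

Averaging probabilities rather than logits keeps each view's contribution bounded, so a single over-confident window cannot dominate the vote.
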
## Training recipe

```bash
.venv/bin/python scripts/train_ecg_rhythm_scratch.py \
    --arch resnet1d --window-sizes 10 \
    --epochs 30 --batch-size 64 --lr 5e-4 \
    --base-channels 64 \
    --use-val-as-train \
    --classes NSR AFIB SBR AB SVTA B \
    --output-dir results/ecg_classifier/sweep/c6_resnet1d_w10_e30_wide
```

- Dataset: LTAF train+val combined (75 records), of which 8 records are
  held out for early stopping. The test set (9 records, 3,716 windows)
  is untouched.
- Loss: weighted cross-entropy with sqrt-dampened inverse-frequency
  class weights (cap 10), label smoothing 0.1.
- Optimizer: AdamW (weight decay 1e-4) with a cosine LR schedule from
  5e-4 → 0 over 30 epochs.
- Checkpoint selection: best held-out macro F1.
- Training time on a single H100 80GB: **~6 minutes**.

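
The loss weighting and learning-rate schedule listed above can be sketched in a few lines. This is one plausible reading, not the training script itself: the weight formula `min(cap, sqrt(n_max / n_c))` is an assumed interpretation of "sqrt-dampened inverse-frequency", the schedule is assumed to step per epoch, and the class counts are made up for illustration.

```python
import math


def sqrt_inv_freq_weights(counts, cap=10.0):
    """Assumed form: w_c = min(cap, sqrt(n_max / n_c)).

    The most frequent class gets weight 1; rare classes are up-weighted
    by the square root of their frequency ratio, clipped at `cap`.
    """
    n_max = max(counts.values())
    return {c: min(cap, math.sqrt(n_max / n)) for c, n in counts.items()}


def cosine_lr(epoch, total_epochs=30, base_lr=5e-4):
    """Cosine decay from base_lr at epoch 0 to 0 at total_epochs."""
    return 0.5 * base_lr * (1.0 + math.cos(math.pi * epoch / total_epochs))


# Hypothetical training-window counts, for illustration only.
counts = {"NSR": 40000, "AFIB": 25000, "SBR": 6400, "AB": 2500, "SVTA": 100, "B": 1600}
weights = sqrt_inv_freq_weights(counts)
print({c: round(w, 3) for c, w in weights.items()})
print([cosine_lr(e) for e in (0, 15, 30)])
```

In PyTorch, weights like these would typically be passed as a tensor to `torch.nn.CrossEntropyLoss(weight=..., label_smoothing=0.1)`. With the made-up counts, the rarest class (SVTA) would get an uncapped weight of 20 and is clipped to 10.
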
Source repo: `scripts/train_ecg_rhythm_scratch.py` and
`src/models/ts_llm/ecg_rhythm_scratch.py` in
[rmxjck/TSLM-Arena](https://github.com/rmxjck/TSLM-Arena).

## Test set details

LTAF held-out split (deterministic seed 42, record-level): 9 records
(`100, 104, 105, 11, 200, 32, 48, 49, 68`), 3,716 windows.

Confusion matrix (rows = true, cols = pred). The counts reproduce the
single-window column of the results table (accuracy 2365/3716 = 0.636,
balanced accuracy 0.740, macro F1 0.614), not the TTA column:

|       | NSR | AFIB | SBR | AB | SVTA | B |
|-------|----:|-----:|----:|----:|-----:|----:|
| NSR | 1109 | 286 | 95 | 114 | 185 | 35 |
| AFIB | 189 | 628 | 25 | 29 | 294 | 14 |
| SBR | 26 | 0 | 279 | 0 | 0 | 0 |
| AB | 9 | 13 | 0 | 225 | 3 | 0 |
| SVTA | 9 | 14 | 0 | 3 | 34 | 0 |
| B | 4 | 0 | 0 | 1 | 3 | 90 |

Per-class supports: NSR 1824, AFIB 1179, SBR 305, AB 250, SVTA 60, B 98.

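
The supports and headline metrics can be recomputed directly from the confusion matrix above. The sketch below uses plain Python and the standard definitions: recall averaged over classes for balanced accuracy, and per-class F1 = 2·TP / (support + predicted) averaged for macro F1.

```python
labels = ["NSR", "AFIB", "SBR", "AB", "SVTA", "B"]
cm = [  # rows = true class, columns = predicted class
    [1109, 286, 95, 114, 185, 35],
    [189, 628, 25, 29, 294, 14],
    [26, 0, 279, 0, 0, 0],
    [9, 13, 0, 225, 3, 0],
    [9, 14, 0, 3, 34, 0],
    [4, 0, 0, 1, 3, 90],
]

n = len(labels)
support = [sum(row) for row in cm]  # true windows per class
predicted = [sum(cm[i][j] for i in range(n)) for j in range(n)]  # predicted per class
correct = [cm[i][i] for i in range(n)]

accuracy = sum(correct) / sum(support)
recall = [correct[i] / support[i] for i in range(n)]
balanced_accuracy = sum(recall) / n
# F1 = 2*TP / (2*TP + FP + FN) = 2*TP / (support + predicted)
f1 = [2 * correct[i] / (support[i] + predicted[i]) for i in range(n)]
macro_f1 = sum(f1) / n

print(dict(zip(labels, support)))
print(f"accuracy={accuracy:.3f} balanced={balanced_accuracy:.3f} macro_f1={macro_f1:.3f}")
# accuracy=0.636 balanced=0.740 macro_f1=0.614
```

These values match the single-window column of the results table at the top of the card. The row sums also reproduce the listed per-class supports, and they total 3,716 windows.
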
## What was tried and didn't help

This model was the best of 30+ experiments. The following did *not*
improve over it:

- HRV side-channel input (8-dim RR-derived features fused with the CNN
  trunk): hurts F1 by 3-8 pp because the CNN already extracts equivalent
  information from raw QRS timing.
- Cross-corpus augmentation (MIT-BIH AFDB added to training): hurts
  AFIB F1 by 14 pp because AFDB's clean AFIB blocks bias the model
  toward over-calling AFIB on LTAF's paroxysmal transitions.
- Wider models (96-channel, 12 M params): overfit.
- Longer training (50 epochs): overfits.
- Multi-model soft-voting ensembles: members make correlated errors.
- Focal loss: matches CE within noise.
- Multi-scale training (5 / 10 / 30 s windows): underperforms 10 s alone.
- Bigger external models (torchecg ResNet-50 at 51.9 M params, Stanford
  at 27 M params): underperform a 2.2 M-parameter home-rolled ResNet1D
  at 12 epochs.

## Not for clinical use

Research artifact only. **Not FDA-cleared.** Not suitable for triage,
diagnosis, or any patient-facing application. Uses the LTAF benchmark,
which has known label noise from its original PhysioNet curation.

## Citation

```bibtex
@misc{petrutiu2008ltafdb,
  title = {Abrupt Changes in Fibrillatory Wave Characteristics at the Termination of Paroxysmal Atrial Fibrillation in Humans},
  author = {Petrutiu, Simona and Sahakian, Alan V. and Swiryn, Steven},
  year = {2008},
  howpublished = {PhysioNet},
  url = {https://physionet.org/content/ltafdb/}
}
```