---
library_name: pytorch
tags:
- ecg
- arrhythmia
- rhythm-classification
- ltaf
- physionet
- 1d-resnet
license: mit
datasets:
- physionet/ltafdb
---
# LTAF ECG Rhythm Classifier: RhythmResNet1D + TTA
A from-scratch 1D-ResNet trained on PhysioNet's
[Long-Term Atrial Fibrillation (LTAF)](https://physionet.org/content/ltafdb/)
database for **6-class rhythm classification** on two-lead 128 Hz ECG.
| Metric | Single-window | **+ 7-view TTA (recommended)** |
|---|---:|---:|
| Test accuracy | 0.636 | **0.684** |
| Test balanced accuracy | 0.740 | **0.778** |
| **Test macro F1** | **0.614** | **0.656** |
vs. frozen Chronos-2 + MLP baseline on the same 6-class subset:
test macro F1 = 0.299, i.e. **+36 pp / 2.2× the F1**.
Per-class F1 (TTA-7): NSR 0.76, AFIB 0.62, SBR 0.82, AB 0.77, SVTA 0.15, B 0.82.
## Classes
| Code | Expansion |
|------|-----------|
| NSR | Normal sinus rhythm |
| AFIB | Atrial fibrillation |
| SBR | Sinus bradycardia (<60 bpm, sinus origin) |
| AB | Atrial bigeminy (every other beat is an APC) |
| SVTA | Supraventricular tachyarrhythmia (≥3 consec SV ectopics @ >100 bpm) |
| B | Ventricular bigeminy (every other beat is a PVC) |
`VT`, `T`, and `IVR` are excluded: their LTAF test supports (31, 26, 1) are too small for stable F1 estimation.
## Quickstart
```bash
pip install torch huggingface_hub numpy
```
```python
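import sys
from pathlib import Path

import numpy as np
import torch
from huggingface_hub import hf_hub_download

REPO_ID = "rmxjck/ltaf-ecg-rhythm-classifier"
# Download checkpoint + model code from HF.
# Assumption: the model code ships as model.py next to the checkpoint.
ckpt = hf_hub_download(REPO_ID, "best_classifier.pt")
sys.path.insert(0, str(Path(hf_hub_download(REPO_ID, "model.py")).parent))
from model import RhythmResNet1D, RHYTHM_CLASS_NAMES

# Fall back to CPU when no GPU is available.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = RhythmResNet1D.load(ckpt, device=device)
model.eval()
# Input: (B, 2, 1280) = 10 s @ 128 Hz, 2 leads, per-channel z-scored.
x = torch.randn(1, 2, 1280, device=device)  # replace with real ECG
with torch.no_grad():
    logits = model(x)
pred_idx = logits.argmax(-1).item()
print(model.class_names[pred_idx])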
```
For best results, use the **7-view TTA** wrapper in `inference.py`
(averages softmax across 7 random window-start offsets; adds ~4 pp F1
at the cost of 7× inference compute).
```bash
python inference.py
```
## Architecture
`RhythmResNet1D(num_classes=6, n_channels=2, base_channels=64,
blocks_per_stage=2)`:
- **Stem:** Conv1d(2, 64, k=15, stride=2) → BN → ReLU → MaxPool(2).
- **4 ResNet stages × 2 basic blocks** (Conv1d k=7, BN, ReLU, Dropout, +skip).
  Channels: 64 → 128 → 256 → 512. Time downsamples 2× at the start of each
  stage past the first.
- **Head:** AdaptiveAvgPool1d → Linear(512 → 128) → ReLU → Dropout(0.2)
  → Linear(128 → 6).
- **Total parameters:** 8,794,246.
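
To make the downsampling concrete, the sketch below traces the time axis of one 1280-sample window through the stem and the four stages. It assumes padding is chosen so that every stride-2 operation halves the length exactly; the padding values are not stated above, so treat the lengths as approximate. `trace_time_axis` is an illustrative helper, not part of `model.py`.

```python
def trace_time_axis(n_samples=1280):
    """Return (layer, channels, length) tuples for the time axis.

    Assumption: each stride-2 op halves the length exactly.
    """
    steps = []
    length = n_samples // 2          # stem Conv1d, stride 2
    steps.append(("stem conv", 64, length))
    length //= 2                     # stem MaxPool(2)
    steps.append(("stem pool", 64, length))
    for i, channels in enumerate([64, 128, 256, 512]):
        if i > 0:                    # stages past the first downsample by 2
            length //= 2
        steps.append((f"stage {i + 1}", channels, length))
    return steps


steps = trace_time_axis()
for name, channels, length in steps:
    print(f"{name:10s} channels={channels:3d} length={length}")

# Overall temporal stride before global average pooling.
total_stride = 1280 // steps[-1][2]
print(f"one output step covers {total_stride} samples "
      f"({total_stride / 128:.2f} s at 128 Hz)")
```

Under that assumption the last stage sees 40 time steps of 512 channels, which `AdaptiveAvgPool1d` then collapses to a single 512-dimensional vector for the head.
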
## Input format
- `(B, 2, 1280)` float32
- 2-lead ECG at **128 Hz** (LTAF leads `ECG1`, `ECG2`)
- 10 s window
- Per-channel z-scored: `(x - x.mean(axis=-1, keepdims=True)) / x.std(axis=-1, keepdims=True)`
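
The input contract above can be sketched in NumPy as follows. The helper names `zscore_window` and `make_batch` are hypothetical, and the small `eps` guard against flat (zero-variance) leads is an addition that is not part of the stated formula.

```python
import numpy as np


def zscore_window(window, eps=1e-8):
    """Z-score each lead of a (2, n) window independently."""
    window = np.asarray(window, dtype=np.float32)
    mean = window.mean(axis=-1, keepdims=True)
    std = window.std(axis=-1, keepdims=True)
    return (window - mean) / (std + eps)  # eps is an assumption, see lead-in


def make_batch(signal, start=0, window_len=1280):
    """Cut one 10 s window from a (2, n_samples) recording and add a batch axis."""
    window = signal[:, start:start + window_len]
    if window.shape != (2, window_len):
        raise ValueError(f"expected a (2, {window_len}) window, got {window.shape}")
    return zscore_window(window)[None, ...]


# Synthetic stand-in for 30 s of two-lead ECG sampled at 128 Hz.
rng = np.random.default_rng(0)
signal = rng.normal(loc=0.3, scale=2.0, size=(2, 30 * 128))
batch = make_batch(signal, start=256)
print(batch.shape, batch.dtype)
```

The resulting array can be handed to the model after `torch.from_numpy(batch)`.
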
## Test-time augmentation (TTA)
Pass a longer signal slice (≥1280 samples) to `predict_tta()` and it
samples 7 random 10 s windows, averages the softmax outputs, then
argmaxes. Why it helps: training uses random window-start sampling
within each rhythm bout, so the model learns to be invariant to that
shift. At eval time, taking multiple shifts and averaging cancels the
position-specific noise. **+4.2 pp test macro F1, no retraining.**
```python
from inference import predict_tta  # TTA wrapper from inference.py

# long_signal: (2, 30 * 128) signal, 30 s long
cls, prob, full_probs = predict_tta(model, long_signal, n_views=7, device="cuda")
```
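
The averaging step itself is simple. Below is a minimal, framework-agnostic sketch of the idea, not the actual `predict_tta()` implementation in `inference.py`: the names `tta_predict_sketch`, `logits_fn`, and `dummy_logits_fn` are hypothetical, and the dummy function only stands in for the trained network.

```python
import numpy as np


def softmax(logits):
    """Numerically stable softmax over the last axis."""
    shifted = logits - logits.max(axis=-1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=-1, keepdims=True)


def tta_predict_sketch(logits_fn, signal, n_views=7, window_len=1280, seed=0):
    """Average softmax outputs over randomly offset windows of a (2, n) signal."""
    n_samples = signal.shape[-1]
    if n_samples < window_len:
        raise ValueError("signal is shorter than one window")
    rng = np.random.default_rng(seed)
    # Random window starts; the upper bound is exclusive.
    starts = rng.integers(0, n_samples - window_len + 1, size=n_views)
    views = np.stack([signal[:, s:s + window_len] for s in starts])
    # Per-channel z-score of every view, matching the input format.
    mean = views.mean(axis=-1, keepdims=True)
    std = views.std(axis=-1, keepdims=True)
    views = (views - mean) / (std + 1e-8)
    probs = softmax(logits_fn(views)).mean(axis=0)  # average over views
    return int(probs.argmax()), probs


def dummy_logits_fn(views):
    """Stand-in for the network: (n_views, 2, 1280) -> (n_views, 6)."""
    weights = np.random.default_rng(1).normal(size=(2, 6))
    return views.max(axis=-1) @ weights


long_signal = np.random.default_rng(0).normal(size=(2, 30 * 128))
cls_idx, probs = tta_predict_sketch(dummy_logits_fn, long_signal)
print(cls_idx, probs.round(3))
```

With the real model, `logits_fn` could wrap a forward pass under `torch.no_grad()` that converts the views to a tensor and the logits back to NumPy.
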
## Training recipe
```bash
.venv/bin/python scripts/train_ecg_rhythm_scratch.py \
--arch resnet1d --window-sizes 10 \
--epochs 30 --batch-size 64 --lr 5e-4 \
--base-channels 64 \
--use-val-as-train \
--classes NSR AFIB SBR AB SVTA B \
--output-dir results/ecg_classifier/sweep/c6_resnet1d_w10_e30_wide
```
- Dataset: LTAF train+val combined (75 records). 8 records held out for
early stopping. Test (9 records, 3,716 windows) untouched.
- Loss: weighted cross-entropy with sqrt-dampened inverse-frequency
class weights (cap 10), label smoothing 0.1.
- Cosine LR schedule from 5e-4 → 0 over 30 epochs. AdamW (wd 1e-4).
- Best checkpoint by held-out macro F1.
- Training time on a single H100 80GB: **~6 minutes**.
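
Two items in the recipe above are easy to misread, so here is a small pure-Python sketch of each. The weight formula `min(sqrt(N / (K * n_c)), cap)` is one plausible reading of "sqrt-dampened inverse-frequency" and may differ from what the training script computes. The counts are the test-set supports reported in this card, used only as placeholders because the training-set counts are not listed here. The function names are hypothetical.

```python
import math


def sqrt_inverse_frequency_weights(counts, cap=10.0):
    """Assumed form: w_c = min(sqrt(N / (K * n_c)), cap)."""
    total = sum(counts.values())
    n_classes = len(counts)
    return {
        name: min(math.sqrt(total / (n_classes * n)), cap)
        for name, n in counts.items()
    }


def cosine_lr(epoch, total_epochs=30, lr_max=5e-4):
    """Cosine annealing from lr_max at epoch 0 down to 0 at total_epochs."""
    return 0.5 * lr_max * (1.0 + math.cos(math.pi * epoch / total_epochs))


# Placeholder counts (test supports from this card, not training counts).
counts = {"NSR": 1824, "AFIB": 1179, "SBR": 305, "AB": 250, "SVTA": 60, "B": 98}
weights = sqrt_inverse_frequency_weights(counts)
for name, weight in weights.items():
    print(f"{name:5s} weight={weight:.3f}")

for epoch in (0, 15, 30):
    print(f"epoch {epoch:2d} lr={cosine_lr(epoch):.2e}")
```

The square root keeps rare classes such as SVTA from dominating the loss the way plain inverse-frequency weights would. In PyTorch, weights like these would typically be passed to `torch.nn.CrossEntropyLoss(weight=..., label_smoothing=0.1)`; whether the training script does exactly that is not confirmed here.
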
Source repo: `scripts/train_ecg_rhythm_scratch.py` and
`src/models/ts_llm/ecg_rhythm_scratch.py` in
[rmxjck/TSLM-Arena](https://github.com/rmxjck/TSLM-Arena).
## Test set details
LTAF held-out split (deterministic seed 42, record-level): 9 records
(`100, 104, 105, 11, 200, 32, 48, 49, 68`), 3,716 windows.
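Confusion matrix (rows = true, cols = pred), single-window. The counts below reproduce the single-window accuracy (0.636) and macro F1 (0.614), not the TTA figures: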
| | NSR | AFIB | SBR | AB | SVTA | B |
|-------|----:|-----:|----:|----:|-----:|----:|
| NSR | 1109 | 286 | 95 | 114 | 185 | 35 |
| AFIB | 189 | 628 | 25 | 29 | 294 | 14 |
| SBR | 26 | 0 | 279 | 0 | 0 | 0 |
| AB | 9 | 13 | 0 | 225 | 3 | 0 |
| SVTA | 9 | 14 | 0 | 3 | 34 | 0 |
| B | 4 | 0 | 0 | 1 | 3 | 90 |
Per-class supports: NSR 1824, AFIB 1179, SBR 305, AB 250, SVTA 60, B 98.
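
The headline metrics can be recomputed from the matrix with a few lines of pure Python, which is a handy check on which evaluation mode a set of counts belongs to. The `summarize` helper is illustrative and not part of the released code.

```python
CLASSES = ["NSR", "AFIB", "SBR", "AB", "SVTA", "B"]
# Rows = true class, columns = predicted class (counts copied from the table).
CM = [
    [1109, 286, 95, 114, 185, 35],
    [189, 628, 25, 29, 294, 14],
    [26, 0, 279, 0, 0, 0],
    [9, 13, 0, 225, 3, 0],
    [9, 14, 0, 3, 34, 0],
    [4, 0, 0, 1, 3, 90],
]


def summarize(cm):
    """Accuracy, balanced accuracy, and macro F1 from a square confusion matrix."""
    n = len(cm)
    total = sum(sum(row) for row in cm)
    support = [sum(row) for row in cm]                       # true counts
    predicted = [sum(cm[i][j] for i in range(n)) for j in range(n)]
    tp = [cm[i][i] for i in range(n)]
    recall = [tp[i] / support[i] for i in range(n)]
    # F1 = 2*TP / (2*TP + FP + FN) = 2*TP / (support + predicted)
    f1 = [2 * tp[i] / (support[i] + predicted[i]) for i in range(n)]
    return {
        "accuracy": sum(tp) / total,
        "balanced_accuracy": sum(recall) / n,
        "macro_f1": sum(f1) / n,
        "per_class_f1": dict(zip(CLASSES, f1)),
    }


metrics = summarize(CM)
print({k: round(v, 3) for k, v in metrics.items() if k != "per_class_f1"})
print({k: round(v, 2) for k, v in metrics["per_class_f1"].items()})
```

On the counts above this yields accuracy 0.636, balanced accuracy 0.740, and macro F1 0.614, which match the single-window column of the metrics table.
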
## What was tried and didn't help
This model was the best of 30+ experiments. What did *not* improve over
this baseline:
- HRV side-channel input (8-dim RR-derived features fused with CNN trunk):
hurts F1 by 3-8 pp because the CNN already extracts equivalent
information from raw QRS timing.
- Cross-corpus augmentation (MIT-BIH AFDB added to training): hurts
AFIB F1 by 14 pp because AFDB's clean AFIB blocks bias the model
toward over-calling AFIB on LTAF's paroxysmal transitions.
- Wider models (96-channel, 12 M params): overfit.
- Longer training (50 epochs): overfits.
- Multi-model soft-voting ensembles: members make correlated errors.
- Focal loss: matches CE within noise.
- Multi-scale training (5 / 10 / 30 s windows): underperforms 10 s alone.
- Bigger external models (torchecg ResNet-50 51.9 M, Stanford 27 M):
underperform a 2.2 M home-rolled ResNet1D at 12 epochs.
## Not for clinical use
Research artifact only. **Not FDA-cleared.** Not suitable for triage,
diagnosis, or any patient-facing application. Uses the LTAF benchmark,
which has known label noise from its original PhysioNet curation.
## Citation
```bibtex
@misc{petrutiu2008ltafdb,
title = {Abrupt Changes in Fibrillatory Wave Characteristics at the Termination of Paroxysmal Atrial Fibrillation in Humans},
author = {Petrutiu, Simona and Sahakian, Alan V. and Swiryn, Steven},
year = {2008},
howpublished = {PhysioNet},
url = {https://physionet.org/content/ltafdb/}
}
```