---
library_name: pytorch
tags:
- ecg
- arrhythmia
- rhythm-classification
- ltaf
- physionet
- 1d-resnet
license: mit
datasets:
- physionet/ltafdb
---

# LTAF ECG Rhythm Classifier — RhythmResNet1D + TTA

A from-scratch 1D-ResNet trained on PhysioNet's [Long-Term Atrial Fibrillation (LTAF)](https://physionet.org/content/ltafdb/) database for **6-class rhythm classification** on two-lead 128 Hz ECG.

| Metric | Single-window | **+ 7-view TTA (recommended)** |
|---|---:|---:|
| Test accuracy | 0.636 | **0.684** |
| Test balanced accuracy | 0.740 | **0.778** |
| **Test macro F1** | **0.614** | **0.656** |

For comparison, a frozen Chronos-2 + MLP baseline on the same 6-class subset reaches test macro F1 = 0.299, so the TTA model is **+36 pp (2.2×) higher in macro F1**.

Per-class F1 (TTA-7): NSR 0.76, AFIB 0.62, SBR 0.82, AB 0.77, SVTA 0.15, B 0.82.

## Classes

| Code | Expansion |
|------|-----------|
| NSR | Normal sinus rhythm |
| AFIB | Atrial fibrillation |
| SBR | Sinus bradycardia (<60 bpm, sinus origin) |
| AB | Atrial bigeminy (every other beat is an APC) |
| SVTA | Supraventricular tachyarrhythmia (≥3 consecutive SV ectopic beats at >100 bpm) |
| B | Ventricular bigeminy (every other beat is a PVC) |

`VT`, `T`, and `IVR` are excluded: their LTAF test supports (31, 26, and 1, respectively) are too small for stable F1 estimation.

## Quickstart

```bash
pip install torch huggingface_hub numpy
```

```python
import os
import sys

import numpy as np
import torch
from huggingface_hub import hf_hub_download

REPO = "rmxjck/ltaf-ecg-rhythm-classifier"
device = "cuda" if torch.cuda.is_available() else "cpu"

# Download checkpoint + model code from HF, then make model.py importable.
ckpt = hf_hub_download(REPO, "best_classifier.pt")
sys.path.insert(0, os.path.dirname(hf_hub_download(REPO, "model.py")))
from model import RhythmResNet1D, RHYTHM_CLASS_NAMES

model = RhythmResNet1D.load(ckpt, device=device)
model.eval()

# Input: (B, 2, 1280), i.e. 10 s @ 128 Hz, 2 leads, per-channel z-scored.
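# Hypothetical helper (a minimal sketch, not part of model.py): build one such
# input from a raw (2, T) recording by slicing a 1280-sample window and
# z-scoring each lead over time; keepdims=True keeps the stats broadcastable.
def make_input(sig, start=0, win=1280, eps=1e-8):
    w = np.asarray(sig, dtype=np.float32)[:, start:start + win]
    if w.shape[-1] != win:
        raise ValueError(f"need {win} samples from start={start}, got {w.shape[-1]}")
    w = (w - w.mean(axis=-1, keepdims=True)) / (w.std(axis=-1, keepdims=True) + eps)
    return w[np.newaxis]  # (1, 2, win) float32; wrap with torch.from_numpy(...) before model(...)
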
x = torch.randn(1, 2, 1280).to(next(model.parameters()).device)  # replace with real ECG

with torch.no_grad():
    logits = model(x)
pred_idx = logits.argmax(-1).item()
print(model.class_names[pred_idx])
```

For best results, use the **7-view TTA** wrapper in `inference.py`. It averages the softmax outputs over 7 random window-start offsets, which adds ~4 pp macro F1 at the cost of 7× inference compute.

```bash
python inference.py
```

## Architecture

`RhythmResNet1D(num_classes=6, n_channels=2, base_channels=64, blocks_per_stage=2)`:

- **Stem:** Conv1d(2, 64, k=15, stride=2) → BN → ReLU → MaxPool(2).
- **4 ResNet stages × 2 basic blocks** (Conv1d k=7, BN, ReLU, Dropout, + skip). Channels: 64 → 128 → 256 → 512. Time is downsampled 2× at the start of every stage after the first.
- **Head:** AdaptiveAvgPool1d → Linear(512 → 128) → ReLU → Dropout(0.2) → Linear(128 → 6).
- **Total parameters:** 8,794,246.

## Input format

- `(B, 2, 1280)` float32
- 2-lead ECG at **128 Hz** (LTAF leads `ECG1`, `ECG2`)
- 10 s window
- Per-channel z-scored over time: `(x - x.mean(axis=-1, keepdims=True)) / x.std(axis=-1, keepdims=True)`

## Test-time augmentation (TTA)

Pass a longer signal slice (≥ 1280 samples) to `predict_tta()`: it samples 7 random 10 s windows, averages their softmax outputs, and takes the argmax.

Why it helps: training samples a random window start within each rhythm bout, so the model learns to be invariant to that shift. At evaluation time, averaging over several shifts cancels position-specific noise. **+4.2 pp test macro F1, no retraining.**

```python
from inference import predict_tta  # the 7-view TTA wrapper in inference.py

# long_signal: (2, 30 * 128) two-lead recording, 30 s long
cls, prob, full_probs = predict_tta(model, long_signal, n_views=7, device="cuda")
```

## Training recipe

```bash
.venv/bin/python scripts/train_ecg_rhythm_scratch.py \
    --arch resnet1d --window-sizes 10 \
    --epochs 30 --batch-size 64 --lr 5e-4 \
    --base-channels 64 \
    --use-val-as-train \
    --classes NSR AFIB SBR AB SVTA B \
    --output-dir results/ecg_classifier/sweep/c6_resnet1d_w10_e30_wide
```

- Dataset: LTAF train + val combined (75 records); 8 records are held out for early stopping.
  Test (9 records, 3,716 windows) is left untouched.
- Loss: weighted cross-entropy with sqrt-dampened inverse-frequency class weights (capped at 10) and label smoothing 0.1.
- Optimizer: AdamW (weight decay 1e-4) with a cosine LR schedule from 5e-4 to 0 over 30 epochs.
- Best checkpoint selected by held-out macro F1.
- Training time on a single H100 80GB: **~6 minutes**.

Source repo: `scripts/train_ecg_rhythm_scratch.py` and `src/models/ts_llm/ecg_rhythm_scratch.py` in [rmxjck/TSLM-Arena](https://github.com/rmxjck/TSLM-Arena).

## Test set details

LTAF held-out split (deterministic seed 42, record-level): 9 records (`100, 104, 105, 11, 200, 32, 48, 49, 68`), 3,716 windows.

Confusion matrix (rows = true, cols = predicted), single-window (no TTA). Its diagonal sums to 2,365 of 3,716 windows (accuracy 0.636), and its per-class recalls and F1 scores average to the single-window balanced accuracy (0.740) and macro F1 (0.614) in the table at the top.

| True ↓ / Pred → | NSR | AFIB | SBR | AB | SVTA | B |
|-------|----:|-----:|----:|----:|-----:|----:|
| **NSR** | 1109 | 286 | 95 | 114 | 185 | 35 |
| **AFIB** | 189 | 628 | 25 | 29 | 294 | 14 |
| **SBR** | 26 | 0 | 279 | 0 | 0 | 0 |
| **AB** | 9 | 13 | 0 | 225 | 3 | 0 |
| **SVTA** | 9 | 14 | 0 | 3 | 34 | 0 |
| **B** | 4 | 0 | 0 | 1 | 3 | 90 |

Per-class supports: NSR 1824, AFIB 1179, SBR 305, AB 250, SVTA 60, B 98.

## What was tried and didn't help

This model was the best of 30+ experiments. The following did *not* improve on it:

- HRV side-channel input (8-dim RR-derived features fused with the CNN trunk): hurts F1 by 3–8 pp, because the CNN already extracts equivalent information from raw QRS timing.
- Cross-corpus augmentation (MIT-BIH AFDB added to training): hurts AFIB F1 by 14 pp, because AFDB's clean AFIB blocks bias the model toward over-calling AFIB on LTAF's paroxysmal transitions.
- Wider models (96-channel, 12 M params): overfit.
- Longer training (50 epochs): overfits.
- Multi-model soft-voting ensembles: members make correlated errors.
- Focal loss: matches CE within noise.
- Multi-scale training (5 / 10 / 30 s windows): underperforms 10 s alone.
- Bigger external models (torchecg ResNet-50 with 51.9 M params, Stanford with 27 M): underperform a 2.2 M home-rolled ResNet1D at 12 epochs.

## Not for clinical use

Research artifact only.
**Not FDA-cleared.** Not suitable for triage, diagnosis, or any patient-facing application. Uses the LTAF benchmark, which has known label noise from its original PhysioNet curation.

## Citation

```bibtex
@misc{petrutiu2008ltafdb,
  title        = {Abrupt Changes in Fibrillatory Wave Characteristics at the Termination of Paroxysmal Atrial Fibrillation in Humans},
  author       = {Petrutiu, Simona and Sahakian, Alan V. and Swiryn, Steven},
  year         = {2008},
  howpublished = {PhysioNet},
  url          = {https://physionet.org/content/ltafdb/}
}
```
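
## Recomputing the summary metrics

The headline metrics can be recomputed from the confusion matrix under "Test set details". The following is a minimal stdlib-only sketch; `summary_metrics` is a hypothetical helper written for illustration, not a function from the source repo. It assumes the standard definitions: balanced accuracy as the mean per-class recall, and F1 as 2·TP / (2·TP + FP + FN), macro-averaged over the 6 classes.

```python
# Confusion matrix from "Test set details": rows = true class, cols = predicted.
CLASSES = ["NSR", "AFIB", "SBR", "AB", "SVTA", "B"]
CM = [
    [1109, 286, 95, 114, 185, 35],
    [189, 628, 25, 29, 294, 14],
    [26, 0, 279, 0, 0, 0],
    [9, 13, 0, 225, 3, 0],
    [9, 14, 0, 3, 34, 0],
    [4, 0, 0, 1, 3, 90],
]


def summary_metrics(cm, names):
    """Hypothetical helper: (accuracy, balanced accuracy, macro F1, per-class F1)."""
    k = len(cm)
    support = [sum(row) for row in cm]                               # TP + FN per class
    predicted = [sum(cm[i][j] for i in range(k)) for j in range(k)]  # TP + FP per class
    tp = [cm[c][c] for c in range(k)]
    accuracy = sum(tp) / sum(support)
    balanced_accuracy = sum(tp[c] / support[c] for c in range(k)) / k  # mean recall
    f1 = [2 * tp[c] / (support[c] + predicted[c]) for c in range(k)]   # 2TP / (2TP + FP + FN)
    return accuracy, balanced_accuracy, sum(f1) / k, dict(zip(names, f1))


acc, bal_acc, macro_f1, per_class = summary_metrics(CM, CLASSES)
print(f"acc={acc:.3f} bal_acc={bal_acc:.3f} macro_f1={macro_f1:.3f}")
# acc=0.636 bal_acc=0.740 macro_f1=0.614
```

The three values match the single-window column of the metrics table, which is how this matrix can be tied to the no-TTA setting.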