---
library_name: pytorch
tags:
- ecg
- classification
- arrhythmia
- ltaf
- physionet
- rhythm
license: mit
---

# LTAF ECG Rhythm Classifier: RhythmFromBeats v2

Beat-embedding rhythm classifier for LTAF. Two-stage pipeline:

1. **Frozen HTF beat embedder** (`htf_embedder.pt`, 1.14 M params) trained
   on **3 corpora** (LTAF + CPSC2021 + AFDB, 1 056 K beats; N 0.977,
   A 0.935, V 0.907 on LTAF held-out test).
2. **Beat-sequence Transformer head** (`rhythm_classifier.pt`, 893 K params)
   trained on the per-beat (576-d) features extracted by the embedder
   for each rhythm bout. Pretrained binary {NSR, AFIB} on 4 corpora
   (LTAF + CPSC2021 + AFDB + Icentia 200), then fine-tuned to LTAF
   6-class (NSR / AFIB / SBR / AB / SVTA / B).

7-view test-time augmentation (TTA-7) reaches **macro F1 = 0.767** on the
LTAF held-out 9-record test set: **+11 pp over the previous v1
RhythmResNet1D + TTA** (0.658) and **+47 pp over the Chronos-2 frozen
baseline** (0.299).

## Why beat embeddings?

The previous v1 (a from-scratch 1D-ResNet on raw 10-s windows) was
limited by patient-distribution shift in LTAF: beat morphology
generalizes well across patients (HTF F1 = 0.94), but rhythm
classification on raw signal struggled to factor that shift out.

The v2 pipeline decouples the two: the embedder extracts patient-robust
beat features, and the rhythm head only learns *sequence patterns*
(irregular RR for AFIB, N-V-N-V for B, fast regular for SVTA, etc.) on
top. This collapses the patient-shift bottleneck. **SVTA F1 went from
0.15 → 0.71 (+57 pp on the worst class)** between v1 and v2 + TTA-7.

## Classes

| Code | Expansion |
|------|-----------|
| NSR | Normal sinus rhythm |
| AFIB | Atrial fibrillation |
| SBR | Sinus bradycardia (<60 bpm, sinus origin) |
| AB | Atrial bigeminy (every other beat is an APC) |
| SVTA | Supraventricular tachyarrhythmia (≥3 consecutive SV ectopics at >100 bpm) |
| B | Ventricular bigeminy (every other beat is a PVC) |

`VT` (31 test windows), `T` (26), and `IVR` (1) were dropped because their
supports are too small for stable F1 estimation.

## Test results: LTAF held-out (9 records, 3 716 windows)

### Single-window (no TTA)

| Metric | Value |
|---|---:|
| Accuracy | 0.713 |
| Balanced accuracy | 0.805 |
| Macro F1 | **0.731** |

### TTA-7 (recommended)

| Metric | Value |
|---|---:|
| Accuracy | 0.760 |
| Balanced accuracy | 0.894 |
| **Macro F1** | **0.767** |

Per-class F1 with TTA-7:

| Class | F1 | v1+TTA-7 | Δ (pp) |
|---|---:|---:|---:|
| NSR | 0.754 | 0.756 | -0.2 |
| AFIB | 0.759 | 0.619 | **+14.0** |
| SBR | 0.759 | 0.821 | -6.2 |
| AB | 0.782 | 0.769 | +1.3 |
| **SVTA** | **0.715** | 0.148 | **+56.7** |
| B | 0.835 | 0.821 | +1.4 |
| Macro | **0.767** | 0.658 | **+10.9** |
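
As a quick check on the per-class table, the sketch below recomputes the v2 macro F1 as the unweighted mean of the six per-class scores and the per-class deltas in percentage points. The numbers are copied from the table; the helper names `macro_f1` and `delta_pp` are illustrative and not part of the repository.

```python
# Per-class F1 values copied from the table above.
v2_f1 = {"NSR": 0.754, "AFIB": 0.759, "SBR": 0.759,
         "AB": 0.782, "SVTA": 0.715, "B": 0.835}
v1_f1 = {"NSR": 0.756, "AFIB": 0.619, "SBR": 0.821,
         "AB": 0.769, "SVTA": 0.148, "B": 0.821}


def macro_f1(per_class):
    """Unweighted mean of per-class F1 scores."""
    return sum(per_class.values()) / len(per_class)


def delta_pp(new, old):
    """Difference in percentage points, rounded to one decimal."""
    return round((new - old) * 100, 1)


v2_macro = macro_f1(v2_f1)
deltas = {name: delta_pp(v2_f1[name], v1_f1[name]) for name in v2_f1}

print(f"v2 + TTA-7 macro F1 = {v2_macro:.3f}")
for name, d in deltas.items():
    print(f"{name:>4}: {d:+.1f} pp")
```

Because macro F1 weights all six classes equally, the large SVTA gain moves the headline number far more than its share of test windows would under plain accuracy.
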

## Inference

```python
import torch
from src.models.ts_llm.ecg_beat_htf import EcgBeatHTFClassifier
from src.models.ts_llm.rhythm_from_beats import RhythmFromBeats, htf_fused_features

device = "cuda" if torch.cuda.is_available() else "cpu"
htf = EcgBeatHTFClassifier.load("htf_embedder.pt", device=device).eval()
model = RhythmFromBeats.load("rhythm_classifier.pt", device=device).eval()

# For each rhythm bout (10 s @ 128 Hz, 2 leads):
#   1. find R-peak indices in the bout
#   2. for each R-peak, extract a 2 s window (+/- 128 samples) and
#      build (rr_history, label_history) from preceding K=5 beats
#   3. run HTF on the per-beat batch -> (n_beats, 576) features
#   4. concat (rr-to-prev, normalized seconds) -> (n_beats, 577)
#   5. pass to rhythm head with valid_mask
# B = bouts in the batch, T = beats per bout (padded).
beat_signals  = ...   # (B, T, 2, 256) per-beat windows, z-scored
rr_history    = ...   # (B, T, 5) RR intervals to preceding 5 beats
label_history = ...   # (B, T, 5) preceding 5 beat labels (-1 if N/A)
rr_extra      = ...   # (B, T, 1) RR-to-prev in this bout
valid_mask    = ...   # (B, T) True at valid beat positions

with torch.no_grad():
    # Flatten (B, T, ...) to (B*T, ...) so the embedder sees one beat per row.
    flat_sig = beat_signals.reshape(-1, 2, 256)
    flat_rr = rr_history.reshape(-1, 5)
    flat_lab = label_history.reshape(-1, 5)
    feats = htf_fused_features(htf, flat_sig, flat_rr, flat_lab)   # (B*T, 576)
    feats = feats.reshape(beat_signals.size(0), beat_signals.size(1), -1)
    logits = model(feats, rr_extra, valid_mask)                    # (B, 6)
    pred = logits.argmax(-1)   # 0=NSR, 1=AFIB, 2=SBR, 3=AB, 4=SVTA, 5=B
```
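
Steps 1, 2, 4 and 5 in the comments above leave the input preparation to the caller. The following is a minimal plain-Python sketch of that preparation on a synthetic bout, not the repository's implementation. Assumptions: windows are z-scored per lead, R-peaks closer than 128 samples to a bout edge are skipped, the RR history is ordered most recent first and zero-padded, and RR values are left in plain seconds (the card mentions a normalization that is not specified here). The helpers `beat_windows`, `rr_features` and `pad_with_mask` are hypothetical names, and `label_history` is omitted because it needs beat labels.

```python
import math
from statistics import mean, pstdev

FS = 128         # sampling rate (Hz), from the bout description above
HALF = 128       # +/- 128 samples around each R-peak -> 256 samples (2 s)
K = 5            # preceding beats kept in the RR history
MAX_BEATS = 32   # padded sequence length; matches --max-beats 32 in Training


def zscore(x):
    """Z-score one lead of one window (assumed axis); flat windows map to 0."""
    mu, sd = mean(x), pstdev(x)
    return [(v - mu) / sd if sd > 0 else 0.0 for v in x]


def beat_windows(signal, r_peaks):
    """Cut one (2, 256) window per R-peak, skipping peaks too close to an edge."""
    n = len(signal[0])
    kept = [r for r in r_peaks if r - HALF >= 0 and r + HALF <= n]
    windows = [[zscore(lead[r - HALF:r + HALF]) for lead in signal] for r in kept]
    return kept, windows


def rr_features(r_peaks, k=K):
    """Return (rr_prev, rr_hist): RR-to-previous in seconds per beat, and the
    last k intervals per beat, most recent first, zero-padded (assumed layout)."""
    rr_prev = [0.0] + [(b - a) / FS for a, b in zip(r_peaks, r_peaks[1:])]
    rr_hist = []
    for i in range(len(r_peaks)):
        recent = rr_prev[max(1, i - k + 1):i + 1][::-1]
        rr_hist.append(recent + [0.0] * (k - len(recent)))
    return rr_prev, rr_hist


def pad_with_mask(items, fill, max_beats=MAX_BEATS):
    """Truncate or pad a per-beat list to max_beats and build the valid mask."""
    items = items[:max_beats]
    n_pad = max_beats - len(items)
    return items + [fill] * n_pad, [True] * len(items) + [False] * n_pad


# Synthetic 10 s, 2-lead bout with hand-picked R-peak positions.
signal = [[math.sin(0.05 * t + lead) for t in range(10 * FS)] for lead in range(2)]
r_peaks = [100, 230, 350, 490, 600, 740, 860, 990, 1100, 1240]

kept, windows = beat_windows(signal, r_peaks)
rr_prev, rr_hist = rr_features(kept)
rr_extra, valid_mask = pad_with_mask(rr_prev, fill=0.0)
print(len(kept), "beats kept;", sum(valid_mask), "valid positions of", len(valid_mask))
```

In a real pipeline these lists would be converted to tensors with the shapes listed in the inference snippet before calling `htf_fused_features`.
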

For best results, use 7-view TTA at inference; see
[`scripts/eval_with_tta_rfb.py`](https://github.com/), which
re-rolls the within-bout window offset 7 times and averages the softmax outputs.
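
The averaging step can be sketched in a few lines. This is an illustration of softmax-averaged TTA in general, not a copy of `eval_with_tta_rfb.py`: the offset range (`max_offset`), the uniform sampling, and the `predict_logits` callable are assumptions made for the example.

```python
import math
import random


def softmax(logits):
    """Numerically stable softmax over a list of logits."""
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def tta_predict(predict_logits, n_views=7, max_offset=64, seed=0):
    """Average class probabilities over n_views randomly shifted windows.

    predict_logits(offset) is a hypothetical callable that builds the beat
    sequence for a window starting `offset` samples into the bout and returns
    the rhythm head's logits for it.
    """
    rng = random.Random(seed)
    summed = None
    for _ in range(n_views):
        probs = softmax(predict_logits(rng.randrange(max_offset + 1)))
        summed = probs if summed is None else [a + b for a, b in zip(summed, probs)]
    mean_probs = [s / n_views for s in summed]
    pred = max(range(len(mean_probs)), key=mean_probs.__getitem__)
    return pred, mean_probs


def toy_logits(offset):
    # Stand-in model: class 1 (AFIB) always wins, class 0 wobbles with the offset.
    return [0.1 * (offset % 3), 1.0, 0.0, 0.0, 0.0, 0.0]


pred, mean_probs = tta_predict(toy_logits, n_views=7)
print(pred, [round(p, 3) for p in mean_probs])
```

Averaging probabilities rather than logits keeps each view's contribution bounded, so a single overconfident view cannot dominate the vote.
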

## Architecture

### HTF embedder (`htf_embedder.pt`, 1 143 939 params)

Three parallel streams concatenated into a 576-dim "fused" feature
before a 3-class N/A/V head (we discard the head and use the fused
feature):

- **Time stream**: 1D-CNN trunk on raw (B, 2, 256) 2-s R-peak window.
  5 conv blocks, base width 32 → 256, AdaptiveAvgPool1d → 256-d.
- **Frequency stream**: same trunk on log-magnitude rFFT (B, 2, 129) →
  256-d.
- **History stream**: MLP on (5 RR intervals, 5×3 one-hot of
  preceding-beat labels) → 64-d.

Trained on **LTAF + CPSC2021 + AFDB** beats (701 561 N + 184 130 A +
169 946 V), AAMI EC57 mapping. Test F1 on LTAF: 0.939.
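
The dimensions quoted for the three streams can be checked with simple arithmetic, and the history-stream input can be written out explicitly. The sketch below is illustrative only: the label index order (N=0, A=1, V=2) and the all-zero row for an unavailable label (-1) are assumptions, and `history_vector` is a hypothetical helper.

```python
N_HISTORY = 5                    # preceding beats, per the description above
BEAT_CLASSES = ("N", "A", "V")   # assumed index order: N=0, A=1, V=2


def history_vector(rr_hist, label_hist):
    """Flatten 5 RR intervals plus a 5x3 one-hot of labels into a 20-d input.

    A label of -1 (not available) is encoded as an all-zero row (assumption).
    """
    assert len(rr_hist) == len(label_hist) == N_HISTORY
    one_hot = []
    for lab in label_hist:
        row = [0.0] * len(BEAT_CLASSES)
        if lab >= 0:
            row[lab] = 1.0
        one_hot.extend(row)
    return list(rr_hist) + one_hot


# A real FFT of a 256-sample window keeps 256 // 2 + 1 frequency bins.
RFFT_BINS = 256 // 2 + 1

# Output widths of the three streams, as listed above.
STREAM_DIMS = {"time": 256, "frequency": 256, "history": 64}
FUSED_DIM = sum(STREAM_DIMS.values())

vec = history_vector([0.80, 0.82, 0.79, 0.81, 0.80], [0, 0, 2, 0, -1])
print(len(vec), RFFT_BINS, FUSED_DIM)
```
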

### Rhythm head (`rhythm_classifier.pt`, 892 806 params)

4-layer Transformer encoder over the per-beat (576+1 = 577)-d
sequence:

- **Input projection**: Linear(577, 128).
- **Positional embedding**: learned, max 64 beats.
- **4 × `_TransformerBlock`**: 4-head MHA + 4×128-d MLP with GELU.
- **Masked-mean pool** over valid beat positions.
- **MLP head**: Linear(128, 128) → GELU → Dropout(0.1) → Linear(128, 6).
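
The masked-mean pool is what lets bouts with different beat counts share one padded batch: padded positions are excluded from the average. A minimal plain-Python sketch of the operation on a single sequence follows; the function name and the zero-vector fallback for an all-padding sequence are assumptions, and the real model would do this on tensors.

```python
def masked_mean_pool(tokens, valid_mask):
    """Average the token vectors at positions where valid_mask is True.

    tokens: list of T vectors of equal length D.
    valid_mask: list of T booleans (False marks padding).
    """
    dim = len(tokens[0])
    n_valid = sum(valid_mask)
    if n_valid == 0:
        return [0.0] * dim   # assumed fallback; avoids division by zero
    pooled = [0.0] * dim
    for vec, keep in zip(tokens, valid_mask):
        if keep:
            for j, v in enumerate(vec):
                pooled[j] += v
    return [p / n_valid for p in pooled]


tokens = [[1.0, 2.0], [3.0, 4.0], [100.0, 100.0]]   # last row is padding
mask = [True, True, False]
pooled = masked_mean_pool(tokens, mask)
print(pooled)
```

A plain mean over all three rows would be pulled toward the padding values; the mask keeps the pooled vector at the average of the two real beats.
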

Pretrained binary {NSR, AFIB} on 4 corpora (LTAF train + CPSC2021 +
AFDB + Icentia 200, 8 epochs), then fine-tuned 30 epochs on LTAF
6-class (train+val merged, val-stratified checkpointing on 8 held-back
records).
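
The 892 806 parameter count quoted for the rhythm head can be reproduced by hand from the layer list, given a few assumptions that the card does not state: every linear layer has a bias, attention uses a fused QKV projection plus an output projection, each block carries two LayerNorms, and one final LayerNorm sits before pooling. The actual checkpoint layout may differ; this is only a budget sketch.

```python
def linear(n_in, n_out):
    """Weights plus bias of a fully connected layer."""
    return n_in * n_out + n_out


def layer_norm(dim):
    """Scale and shift vectors of a LayerNorm."""
    return 2 * dim


D_IN, D_MODEL, D_MLP = 577, 128, 4 * 128
N_LAYERS, MAX_BEATS, N_CLASSES = 4, 64, 6

input_proj = linear(D_IN, D_MODEL)
pos_emb = MAX_BEATS * D_MODEL                        # learned positional table
attn = linear(D_MODEL, 3 * D_MODEL) + linear(D_MODEL, D_MODEL)
mlp = linear(D_MODEL, D_MLP) + linear(D_MLP, D_MODEL)
block = attn + mlp + 2 * layer_norm(D_MODEL)         # assumption: 2 norms per block
final_norm = layer_norm(D_MODEL)                     # assumption: one final norm
head = linear(D_MODEL, D_MODEL) + linear(D_MODEL, N_CLASSES)

total = input_proj + pos_emb + N_LAYERS * block + final_norm + head
print(f"per block: {block:,}   total: {total:,}")
```

Under these assumptions the four Transformer blocks account for roughly 89% of the head's parameters, and the total comes out at the figure stated in the section heading.
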

## Training

Reproducer in [rmxjck/TSLM-Arena](https://github.com/) (commit
forthcoming):

```bash
# Stage 0: train multi-corpus HTF beats classifier
.venv/bin/python scripts/train_ecg_beat_htf_multicorpus.py \
    --corpora ltaf cpsc2021 afdb \
    --epochs 12 --batch-size 256 --lr 1e-3 \
    --output-dir results/ecg_classifier/beats_htf_multicorpus

# Stage 1: binary {NSR, AFIB} pretrain on 4 corpora
.venv/bin/python scripts/train_rhythm_from_beats_multicorpus.py \
    --htf-checkpoint results/ecg_classifier/beats_htf_multicorpus/best_classifier.pt \
    --corpora ltaf cpsc2021 afdb icentia \
    --icentia-records 200 --max-bouts-per-record 100 \
    --epochs 8 --batch-size 32 --lr 5e-4 \
    --classes NSR AFIB \
    --output-dir results/ecg_classifier/sweep/c2_rhythm_from_beats_pretrain

# Stage 2: LTAF 6-class fine-tune
.venv/bin/python scripts/train_rhythm_from_beats_finetune.py \
    --htf-checkpoint results/ecg_classifier/beats_htf_multicorpus/best_classifier.pt \
    --pretrained results/ecg_classifier/sweep/c2_rhythm_from_beats_pretrain/best_classifier.pt \
    --window-seconds 10 --max-beats 32 \
    --epochs 30 --batch-size 32 --lr 5e-4 \
    --use-val-as-train \
    --classes NSR AFIB SBR AB SVTA B \
    --output-dir results/ecg_classifier/sweep/c6_rhythm_from_beats_v2_ft

# Eval with TTA-7
.venv/bin/python scripts/eval_with_tta_rfb.py \
    --htf-checkpoint results/ecg_classifier/beats_htf_multicorpus/best_classifier.pt \
    --checkpoint results/ecg_classifier/sweep/c6_rhythm_from_beats_v2_ft/best_classifier.pt \
    --n-views 7 \
    --output results/ecg_classifier/sweep/c6_rhythm_from_beats_v2_ft_tta.json
```

Best val macro F1 0.741 at epoch 23. Test single-view 0.731, TTA-7
0.767. Total training time on a single H100: ~25 min for HTF + ~1 min
for rhythm pretrain + ~20 min for rhythm fine-tune + ~30 s for TTA eval.

## Ablations

| Variant | Test macro F1 |
|---|---:|
| **v2 + TTA-7** | **0.767** |
| v2 single-window | 0.731 |
| v2 + v4 ensemble + TTA-7 | 0.756 |
| v1 (RhythmResNet1D + TTA-16, no beat embedding) | 0.658 |
| v4 (fp32 cache, 6-corpus pretrain w/ mitdb+nsrdb+Icentia 2k) + TTA-7 | 0.711 |
| v3 (fp16 cache regression) + ft | 0.684 |
| v1 + v2 ensemble + TTA-7 | 0.710 |

Notable negative results (didn't help):

- **fp16 feature cache**: 5 pp regression vs live HTF (precision loss
  in the 576-d embeddings hurt the rhythm head's discriminative
  power). v2 uses live HTF (uncached); fp32 cache (v4) helps with
  speed but the bigger pretrain corpus skewed toward NSR.
- **Multi-model ensembles** (v1+v2, v2+v4): mild regressions because
  v1 and v4 are weaker baselines.
- **Multi-corpus MAE pretrain on raw signal** (v3 transformer): F1
  0.34. Beat-level supervision matters more than raw-signal
  reconstruction at this corpus size.
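
To give a feel for the size of the rounding error an fp16 cache introduces, the sketch below round-trips a few floats through IEEE 754 half precision using the standard library. The sample values are made up for illustration and are not real embedding values; the snippet shows only the per-element error bound, and does not by itself explain the observed 5 pp regression.

```python
import struct


def to_fp16_and_back(x):
    """Round-trip a Python float through IEEE 754 half precision."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


# Illustrative values only (not taken from the model).
values = [0.1234567, 1.0009, 3.14159, 57.123]
rel_errors = [abs(v - to_fp16_and_back(v)) / abs(v) for v in values]
max_rel_error = max(rel_errors)

for v, err in zip(values, rel_errors):
    print(f"{v:>10.6f} -> {to_fp16_and_back(v):>10.6f}   rel. error {err:.2e}")
print(f"bound for normal fp16 values: {2 ** -11:.2e}")
```

Half precision keeps 11 significant bits, so each stored feature carries a relative error of up to about 4.9e-4, compared with about 6e-8 for fp32.
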

## Not for clinical use

Research artifact only. Not FDA-cleared. Not suitable for triage,
diagnosis, or any patient-facing application.

## Citations

```bibtex
@misc{petrutiu2008ltafdb,
  title = {Abrupt Changes in Fibrillatory Wave Characteristics at the Termination of Paroxysmal Atrial Fibrillation in Humans},
  author = {Petrutiu, Simona and Sahakian, Alan V. and Swiryn, Steven},
  year = {2008},
  howpublished = {PhysioNet},
  url = {https://physionet.org/content/ltafdb/}
}
```

CPSC 2021, MIT-BIH AFDB and Icentia11k were also used for the
multi-corpus stages.