---
library_name: pytorch
tags:
- ecg
- classification
- arrhythmia
- ltaf
- physionet
- rhythm
license: mit
---

# LTAF ECG Rhythm Classifier: RhythmFromBeats v2

Beat-embedding rhythm classifier for LTAF. Two-stage pipeline:

1. **Frozen HTF beat embedder** (`htf_embedder.pt`, 1.14 M params) trained on **3 corpora** (LTAF + CPSC2021 + AFDB, 1 056 K beats; N 0.977, A 0.935, V 0.907 on LTAF held-out test).
2. **Beat-sequence Transformer head** (`rhythm_classifier.pt`, 893 K params) trained on the per-beat (576-d) features extracted by the embedder for each rhythm bout. Pretrained binary {NSR, AFIB} on 4 corpora (LTAF + CPSC2021 + AFDB + Icentia 200), then fine-tuned to LTAF 6-class (NSR / AFIB / SBR / AB / SVTA / B).

7-view test-time augmentation (TTA-7) reaches **macro F1 = 0.767** on the LTAF held-out 9-record test set: **+11 pp over the previous v1 RhythmResNet1D + TTA** (0.658) and **+47 pp over the Chronos-2 frozen baseline** (0.299).

## Why beat embeddings?

The previous v1 (a from-scratch 1D-ResNet on raw 10-s windows) was limited by patient-distribution shift in LTAF. Beat morphology generalizes well across patients (HTF F1 = 0.94), but rhythm classification on the raw signal struggled to factor the patient-specific component out.

The v2 pipeline decouples the two: the embedder extracts patient-robust beat features, and the rhythm head only learns *sequence patterns* (irregular RR for AFIB, N-V-N-V for B, fast regular for SVTA, etc.) on top. This largely removes the patient-shift bottleneck. **SVTA F1 went from 0.15 → 0.71 (+57 pp on the worst class)** between v1 and v2 + TTA-7.

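
As a rough illustration of what a "sequence pattern" such as irregular RR looks like numerically, the sketch below computes RR intervals and their coefficient of variation from R-peak sample indices at 128 Hz. This is not how the rhythm head works (it learns from beat embeddings, not from a hand-written statistic); the function names and the R-peak positions are made up for the example.

```python
from statistics import mean, pstdev

FS_HZ = 128  # sampling rate stated in this card


def rr_intervals_s(r_peaks, fs=FS_HZ):
    """RR intervals in seconds from sorted R-peak sample indices."""
    return [(b - a) / fs for a, b in zip(r_peaks, r_peaks[1:])]


def rr_cv(r_peaks, fs=FS_HZ):
    """Coefficient of variation of the RR intervals (0 = perfectly regular)."""
    rr = rr_intervals_s(r_peaks, fs)
    return pstdev(rr) / mean(rr)


# Hypothetical R-peak positions (sample indices), for illustration only
regular = [0, 128, 256, 384, 512, 640]    # steady 60 bpm
irregular = [0, 90, 260, 330, 520, 600]   # AFIB-like irregular spacing

print(rr_cv(regular))    # perfectly regular spacing
print(rr_cv(irregular))  # clearly larger for the irregular sequence
```

A high RR variability is typical of AFIB, while bigeminy shows up as an alternating short/long pattern that a single summary statistic like this would not distinguish; that is part of the motivation for a sequence model over per-beat features.
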
## Classes

| Code | Expansion |
|------|-----------|
| NSR | Normal sinus rhythm |
| AFIB | Atrial fibrillation |
| SBR | Sinus bradycardia (<60 bpm, sinus origin) |
| AB | Atrial bigeminy (every other beat is an APC) |
| SVTA | Supraventricular tachyarrhythmia (≥3 consecutive SV ectopics at >100 bpm) |
| B | Ventricular bigeminy (every other beat is a PVC) |

`VT` (31 test windows), `T` (26), `IVR` (1) were dropped because their supports are too small for stable F1 estimation.

## Test results: LTAF held-out (9 records, 3 716 windows)

### Single-window (no TTA)

| Metric | Value |
|---|---:|
| Accuracy | 0.713 |
| Balanced accuracy | 0.805 |
| Macro F1 | **0.731** |

### TTA-7 (recommended)

| Metric | Value |
|---|---:|
| Accuracy | 0.760 |
| Balanced accuracy | 0.894 |
| **Macro F1** | **0.767** |

Per-class F1 with TTA-7:

| Class | F1 | v1+TTA-7 | Δ (pp) |
|---|---:|---:|---:|
| NSR | 0.754 | 0.756 | -0.2 |
| AFIB | 0.759 | 0.619 | **+14.0** |
| SBR | 0.759 | 0.821 | -6.2 |
| AB | 0.782 | 0.769 | +1.3 |
| **SVTA** | **0.715** | 0.148 | **+56.7** |
| B | 0.835 | 0.821 | +1.4 |
| Macro | **0.767** | 0.658 | **+10.9** |

## Inference

```python
import torch
from src.models.ts_llm.ecg_beat_htf import EcgBeatHTFClassifier
from src.models.ts_llm.rhythm_from_beats import RhythmFromBeats, htf_fused_features

# Fall back to CPU when no GPU is available
device = "cuda" if torch.cuda.is_available() else "cpu"
htf = EcgBeatHTFClassifier.load("htf_embedder.pt", device=device).eval()
model = RhythmFromBeats.load("rhythm_classifier.pt", device=device).eval()

# For each rhythm bout (10 s @ 128 Hz, 2 leads):
#   1. find R-peak indices in the bout
#   2. for each R-peak, extract a 2 s window (+/- 128 samples) and
#      build (rr_history, label_history) from preceding K=5 beats
#   3. run HTF on the per-beat batch -> (n_beats, 576) features
#   4. concat (rr-to-prev, normalized seconds) -> (n_beats, 577)
#   5. pass to rhythm head with valid_mask

beat_signals = ...    # (B, T, 2, 256) per-beat windows, z-scored
rr_history = ...      # (B, T, 5) RR intervals to preceding 5 beats
label_history = ...   # (B, T, 5) preceding 5 beat labels (-1 if N/A)
rr_extra = ...        # (B, T, 1) RR-to-prev in this bout
valid_mask = ...      # (B, T) True at valid beat positions

with torch.no_grad():
    # Flatten (B, T, ...) to (B*T, ...) so the embedder sees one beat per row;
    # reshape also handles non-contiguous inputs, unlike view
    flat_sig = beat_signals.reshape(-1, 2, 256)
    flat_rr = rr_history.reshape(-1, 5)
    flat_lab = label_history.reshape(-1, 5)
    feats = htf_fused_features(htf, flat_sig, flat_rr, flat_lab)  # (B*T, 576)
    # Restore the (B, T, 576) beat-sequence layout for the rhythm head
    feats = feats.reshape(beat_signals.size(0), beat_signals.size(1), -1)
    logits = model(feats, rr_extra, valid_mask)  # (B, 6)
    pred = logits.argmax(-1)  # 0=NSR, 1=AFIB, 2=SBR, 3=AB, 4=SVTA, 5=B
```

For best results, use 7-view TTA at inference; see [`scripts/eval_with_tta_rfb.py`](https://github.com/), which re-rolls the within-bout window offset 7 times and averages the softmax outputs.

## Architecture

### HTF embedder (`htf_embedder.pt`, 1 143 939 params)

Three parallel streams concatenated into a 576-dim (256 + 256 + 64) "fused" feature before a 3-class N/A/V head (we discard the head and use the fused feature):

- **Time stream**: 1D-CNN trunk on raw (B, 2, 256) 2-s R-peak window. 5 conv blocks, base width 32 → 256, AdaptiveAvgPool1d → 256-d.
- **Frequency stream**: same trunk on log-magnitude rFFT (B, 2, 129) → 256-d.
- **History stream**: MLP on (5 RR intervals, 5×3 one-hot of preceding-beat labels) → 64-d.

Trained on **LTAF + CPSC2021 + AFDB** beats (701 561 N + 184 130 A + 169 946 V), AAMI EC57 mapping. Test F1 on LTAF: 0.939.

### Rhythm head (`rhythm_classifier.pt`, 892 806 params)

4-layer Transformer encoder over the per-beat (576+1 = 577)-d sequence:

- **Input projection** Linear(577, 128).
- **Positional embedding** learned, max 64 beats.
- **4 × _TransformerBlock**: 4-head MHA + 4×128-d MLP with GELU.
- **Masked-mean pool** over valid beat positions.
- **MLP head** Linear(128, 128) → GELU → Dropout(0.1) → Linear(128, 6).

Pretrained binary {NSR, AFIB} on 4 corpora (LTAF train + CPSC2021 + AFDB + Icentia 200, 8 epochs), then fine-tuned for 30 epochs on LTAF 6-class (train+val merged, val-stratified checkpointing on 8 held-back records).

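
The rhythm head's stated size can be cross-checked with simple arithmetic. The sketch below counts parameters for the layer list above under several assumptions that are not confirmed by this card: every Linear layer has a bias, Q/K/V and the attention output projection are standard 128 → 128 maps, each block has two LayerNorms, and there is one final LayerNorm before pooling. The helper names are hypothetical.

```python
def linear(n_in, n_out):
    """Weights plus bias of a fully connected layer."""
    return n_in * n_out + n_out


def layer_norm(d):
    """Scale and shift vectors of a LayerNorm."""
    return 2 * d


def rhythm_head_params(d_in=577, d=128, n_blocks=4, mlp_ratio=4,
                       max_beats=64, n_classes=6):
    input_proj = linear(d_in, d)
    pos_emb = max_beats * d                      # learned positional table
    attn = linear(d, 3 * d) + linear(d, d)       # fused QKV + output projection
    mlp = linear(d, mlp_ratio * d) + linear(mlp_ratio * d, d)
    block = attn + mlp + 2 * layer_norm(d)       # assumed: two norms per block
    final_norm = layer_norm(d)                   # assumed: one norm before pooling
    head = linear(d, d) + linear(d, n_classes)
    return input_proj + pos_emb + n_blocks * block + final_norm + head


print(rhythm_head_params())  # 892806
```

Under these assumptions the total equals the stated 892 806 parameters, which is consistent with this layout but does not prove it is the actual implementation. Masked-mean pooling, GELU and Dropout contribute no parameters.
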
## Training

Reproducer in [rmxjck/TSLM-Arena](https://github.com/) (commit forthcoming):

```bash
# Stage 0: train multi-corpus HTF beats classifier
.venv/bin/python scripts/train_ecg_beat_htf_multicorpus.py \
    --corpora ltaf cpsc2021 afdb \
    --epochs 12 --batch-size 256 --lr 1e-3 \
    --output-dir results/ecg_classifier/beats_htf_multicorpus

# Stage 1: binary {NSR, AFIB} pretrain on 4 corpora
.venv/bin/python scripts/train_rhythm_from_beats_multicorpus.py \
    --htf-checkpoint results/ecg_classifier/beats_htf_multicorpus/best_classifier.pt \
    --corpora ltaf cpsc2021 afdb icentia \
    --icentia-records 200 --max-bouts-per-record 100 \
    --epochs 8 --batch-size 32 --lr 5e-4 \
    --classes NSR AFIB \
    --output-dir results/ecg_classifier/sweep/c2_rhythm_from_beats_pretrain

# Stage 2: LTAF 6-class fine-tune
.venv/bin/python scripts/train_rhythm_from_beats_finetune.py \
    --htf-checkpoint results/ecg_classifier/beats_htf_multicorpus/best_classifier.pt \
    --pretrained results/ecg_classifier/sweep/c2_rhythm_from_beats_pretrain/best_classifier.pt \
    --window-seconds 10 --max-beats 32 \
    --epochs 30 --batch-size 32 --lr 5e-4 \
    --use-val-as-train \
    --classes NSR AFIB SBR AB SVTA B \
    --output-dir results/ecg_classifier/sweep/c6_rhythm_from_beats_v2_ft

# Eval with TTA-7
.venv/bin/python scripts/eval_with_tta_rfb.py \
    --htf-checkpoint results/ecg_classifier/beats_htf_multicorpus/best_classifier.pt \
    --checkpoint results/ecg_classifier/sweep/c6_rhythm_from_beats_v2_ft/best_classifier.pt \
    --n-views 7 \
    --output results/ecg_classifier/sweep/c6_rhythm_from_beats_v2_ft_tta.json
```

Best val macro F1 0.741 at epoch 23. Test single-view 0.731, TTA-7 0.767.

Total training time on a single H100: ~25 min for HTF + ~1 min for rhythm pretrain + ~20 min for rhythm fine-tune + ~30 s for TTA eval.

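
The TTA step averages softmax probabilities over several views of the same bout. The sketch below shows only that averaging, in plain Python; it does not reproduce how `scripts/eval_with_tta_rfb.py` generates the views (re-rolling the within-bout window offset). The function names and the logits are hypothetical, and three views are used instead of seven to keep the example short.

```python
import math

CLASSES = ["NSR", "AFIB", "SBR", "AB", "SVTA", "B"]


def softmax(logits):
    """Numerically stable softmax over one logit vector."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def tta_average(view_logits):
    """Mean of per-view softmax probabilities, one value per class."""
    probs = [softmax(v) for v in view_logits]
    n_views = len(probs)
    n_classes = len(probs[0])
    return [sum(p[c] for p in probs) / n_views for c in range(n_classes)]


# Made-up logits for three views of one bout (illustration only)
views = [
    [2.0, 1.0, 0.0, 0.0, 0.0, 0.0],  # this view alone would say NSR
    [0.5, 1.5, 0.0, 0.0, 0.0, 0.0],
    [0.2, 2.5, 0.0, 0.0, 0.0, 0.0],
]

avg = tta_average(views)
pred = CLASSES[max(range(len(avg)), key=avg.__getitem__)]
print(pred)  # AFIB
```

Averaging probabilities rather than raw logits keeps each view's contribution bounded, so a single view with extreme logits cannot dominate the final prediction.
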
## Ablations

| Variant | Test macro F1 |
|---|---:|
| **v2 + TTA-7** | **0.767** |
| v2 single-window | 0.731 |
| v2 + v4 ensemble + TTA-7 | 0.756 |
| v1 (RhythmResNet1D + TTA-16, no beat embedding) | 0.658 |
| v4 (fp32 cache, 6-corpus pretrain w/ mitdb+nsrdb+Icentia 2k) + TTA-7 | 0.711 |
| v3 (fp16 cache regression) + ft | 0.684 |
| v1 + v2 ensemble + TTA-7 | 0.710 |

Notable negative results (didn't help):

- **fp16 feature cache**: 5 pp regression vs live HTF (precision loss in the 576-d embeddings hurt the rhythm head's discriminative power). v2 uses live HTF (uncached); fp32 cache (v4) helps with speed but the bigger pretrain corpus skewed toward NSR.
- **Multi-model ensembles** (v1+v2, v2+v4): mild regressions because v1 and v4 are weaker baselines.
- **Multi-corpus MAE pretrain on raw signal** (v3 transformer): F1 0.34. Beat-level supervision matters more than raw-signal reconstruction at this corpus size.

## Not for clinical use

Research artifact only. Not FDA-cleared. Not suitable for triage, diagnosis, or any patient-facing application.

## Citations

```bibtex
@misc{petrutiu2008ltafdb,
  title        = {Abrupt Changes in Fibrillatory Wave Characteristics at the Termination of Paroxysmal Atrial Fibrillation in Humans},
  author       = {Petrutiu, Simona and Sahakian, Alan V. and Swiryn, Steven},
  year         = {2008},
  howpublished = {PhysioNet},
  url          = {https://physionet.org/content/ltafdb/}
}
```

CPSC 2021, MIT-BIH AFDB and Icentia11k were also used for the multi-corpus stages.