---
library_name: pytorch
tags:
- ecg
- classification
- arrhythmia
- ltaf
- physionet
- rhythm
license: mit
---
# LTAF ECG Rhythm Classifier: RhythmFromBeats v2
Beat-embedding rhythm classifier for LTAF (the PhysioNet Long Term AF Database). Two-stage pipeline:
1. **Frozen HTF beat embedder** (`htf_embedder.pt`, 1.14 M params), trained
on **3 corpora** (LTAF + CPSC2021 + AFDB, 1 056 K beats; N 0.977,
A 0.935, V 0.907 on the LTAF held-out test).
2. **Beat-sequence Transformer head** (`rhythm_classifier.pt`, 893 K params),
trained on the per-beat 576-d features that the embedder extracts
for each rhythm bout. Pretrained as a binary {NSR, AFIB} classifier on 4 corpora
(LTAF + CPSC2021 + AFDB + 200 Icentia11k records), then fine-tuned on the
6 LTAF classes (NSR / AFIB / SBR / AB / SVTA / B).
With 7-view test-time augmentation (TTA-7) the model reaches **macro F1 = 0.767** on the
9-record LTAF held-out test set: **+11 pp over the previous v1
RhythmResNet1D + TTA** (0.658) and **+47 pp over the frozen Chronos-2
baseline** (0.299).
## Why beat embeddings?
The previous v1 (a 1D-ResNet trained from scratch on raw 10-s windows) was
limited by patient-distribution shift in LTAF: beat morphology
generalizes well across patients (HTF F1 = 0.94), but rhythm
classification on the raw signal struggled to factor out patient-specific variation.
The v2 pipeline decouples the two: the embedder extracts patient-robust
beat features, and the rhythm head only learns *sequence patterns* on
top of them (irregular RR for AFIB, N-V-N-V for B, fast regular for SVTA, etc.).
This collapses the patient-shift bottleneck: **SVTA F1 went from
0.15 to 0.71 (+57 pp on the worst class)** between v1 and v2 + TTA-7.
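To make the "sequence patterns" idea concrete, the sketch below hand-codes two cues of the kind listed above: RR irregularity (coefficient of variation of RR intervals, high for AFIB) and strict beat-label alternation (N-V-N-V for B, N-A-N-A for AB). These functions are hypothetical illustrations only; the trained head learns its own features from the per-beat sequence and does not call anything like them.

```python
from statistics import mean, pstdev

def rr_irregularity(rr_seconds):
    """Coefficient of variation (std / mean) of RR intervals in seconds.

    Irregular rhythms such as AFIB tend to score higher than sinus rhythm.
    """
    return pstdev(rr_seconds) / mean(rr_seconds)

def strictly_alternates(labels, a="N", b="V"):
    """True if consecutive beat labels always switch between a and b (N-V-N-V...)."""
    if len(labels) < 2:
        return False
    return all({x, y} == {a, b} for x, y in zip(labels, labels[1:]))

regular = [0.80, 0.81, 0.79, 0.80, 0.80]    # near-constant RR, sinus-like
irregular = [0.55, 0.92, 0.61, 1.05, 0.70]  # erratic RR, AFIB-like
print(rr_irregularity(regular) < rr_irregularity(irregular))  # True
print(strictly_alternates("NVNVNV"), strictly_alternates("NNVNVN"))  # True False
```

The same alternation check with `a="N", b="A"` describes the atrial-bigeminy pattern.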
## Classes
| Code | Expansion |
|------|-----------|
| NSR | Normal sinus rhythm |
| AFIB | Atrial fibrillation |
| SBR | Sinus bradycardia (<60 bpm, sinus origin) |
| AB | Atrial bigeminy (every other beat is an APC) |
| SVTA | Supraventricular tachyarrhythmia (≥3 consecutive SV ectopics at >100 bpm) |
| B | Ventricular bigeminy (every other beat is a PVC) |
`VT` (ventricular tachycardia, 31 test windows), `T` (ventricular trigeminy, 26) and
`IVR` (idioventricular rhythm, 1) were dropped: their test supports are too
small for stable F1 estimation.
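The rate thresholds in the class table read most easily in RR terms: heart rate in bpm is 60 / RR (RR in seconds), so SBR's <60 bpm corresponds to a mean RR above 1 s and SVTA's >100 bpm to RR below 0.6 s. The sketch below applies those two definitions literally to a labeled beat sequence, using the embedder's N/A/V beat labels with "A" standing for an SV ectopic. The function names are hypothetical, the sinus-origin condition of SBR is not checked, and this is an illustration of the class definitions, not how LTAF was annotated or how the model decides.

```python
def bpm(rr_seconds):
    """Heart rate in beats per minute implied by one RR interval (seconds)."""
    return 60.0 / rr_seconds

def is_bradycardic(rr_seconds, max_bpm=60.0):
    """SBR-style rate check: mean heart rate below max_bpm (origin not checked)."""
    return bpm(sum(rr_seconds) / len(rr_seconds)) < max_bpm

def has_svta_run(labels, rr_seconds, min_run=3, min_bpm=100.0):
    """True if >= min_run consecutive SV ectopic beats ('A') exceed min_bpm.

    labels[i] is the label of beat i; rr_seconds[i] is the RR interval ending at beat i.
    """
    run = 0
    for lab, rr in zip(labels, rr_seconds):
        if lab == "A" and bpm(rr) > min_bpm:
            run += 1
            if run >= min_run:
                return True
        else:
            run = 0  # run broken by a non-ectopic or too-slow beat
    return False

print(has_svta_run("NNAAAN", [0.8, 0.8, 0.5, 0.5, 0.5, 0.8]))  # True
```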
## Test results: LTAF held-out (9 records, 3 716 windows)
### Single-window (no TTA)
| Metric | Value |
|---|---:|
| Accuracy | 0.713 |
| Balanced accuracy | 0.805 |
| Macro F1 | **0.731** |
### TTA-7 (recommended)
| Metric | Value |
|---|---:|
| Accuracy | 0.760 |
| Balanced accuracy | 0.894 |
| **Macro F1** | **0.767** |
Per-class F1 with TTA-7:
| Class | F1 (v2 + TTA-7) | F1 (v1 + TTA) | Δ (pp) |
|---|---:|---:|---:|
| NSR | 0.754 | 0.756 | -0.2 |
| AFIB | 0.759 | 0.619 | **+14.0** |
| SBR | 0.759 | 0.821 | -6.2 |
| AB | 0.782 | 0.769 | +1.3 |
| **SVTA** | **0.715** | 0.148 | **+56.7** |
| B | 0.835 | 0.821 | +1.4 |
| Macro | **0.767** | 0.658 | **+10.9** |
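Macro F1 is the unweighted mean of the per-class F1 scores, and balanced accuracy is the unweighted mean of per-class recall, so every class counts equally regardless of support. A small sketch of both from a confusion matrix (rows = true class, columns = predicted), plus a check that the per-class TTA-7 F1 values in the table average to the reported 0.767. The helper names are our own, not part of the released code.

```python
def per_class_f1(cm):
    """F1 per class from a square confusion matrix (rows true, cols predicted)."""
    k = len(cm)
    scores = []
    for c in range(k):
        tp = cm[c][c]
        fp = sum(cm[r][c] for r in range(k)) - tp  # predicted c, true class differs
        fn = sum(cm[c]) - tp                      # true c, predicted otherwise
        denom = 2 * tp + fp + fn
        scores.append(2 * tp / denom if denom else 0.0)
    return scores

def macro_f1(cm):
    scores = per_class_f1(cm)
    return sum(scores) / len(scores)

def balanced_accuracy(cm):
    """Mean per-class recall; classes with no true samples are skipped."""
    recalls = [cm[c][c] / sum(cm[c]) for c in range(len(cm)) if sum(cm[c])]
    return sum(recalls) / len(recalls)

# Per-class TTA-7 F1 from the table above (NSR, AFIB, SBR, AB, SVTA, B).
tta7_f1 = [0.754, 0.759, 0.759, 0.782, 0.715, 0.835]
print(round(sum(tta7_f1) / len(tta7_f1), 3))  # 0.767
```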
## Inference
```python
import torch
from src.models.ts_llm.ecg_beat_htf import EcgBeatHTFClassifier
from src.models.ts_llm.rhythm_from_beats import RhythmFromBeats, htf_fused_features

device = "cuda" if torch.cuda.is_available() else "cpu"
htf = EcgBeatHTFClassifier.load("htf_embedder.pt", device=device).eval()
model = RhythmFromBeats.load("rhythm_classifier.pt", device=device).eval()

# For each rhythm bout (10 s @ 128 Hz, 2 leads):
#   1. find the R-peak indices in the bout
#   2. for each R-peak, extract a 2 s window (+/- 128 samples) and
#      build (rr_history, label_history) from the preceding K=5 beats
#   3. run HTF on the per-beat batch -> (n_beats, 576) fused features
#   4. pass RR-to-previous-beat (normalized seconds) as rr_extra; together
#      with the 576-d features it forms the head's 577-d per-beat input
#   5. run the rhythm head with valid_mask marking real (non-padded) beats
# The "..." placeholders stand for the outputs of steps 1-2, as tensors on `device`.
beat_signals = ...   # (B, T, 2, 256) per-beat windows, z-scored
rr_history = ...     # (B, T, 5) RR intervals to the preceding 5 beats
label_history = ...  # (B, T, 5) preceding 5 beat labels (-1 where no beat exists)
rr_extra = ...       # (B, T, 1) RR-to-previous-beat within this bout
valid_mask = ...     # (B, T) bool, True at valid beat positions

with torch.no_grad():
    B, T = beat_signals.shape[:2]
    # Flatten batch and beat axes so HTF sees one beat per row;
    # reshape (unlike view) also works on non-contiguous tensors.
    flat_sig = beat_signals.reshape(B * T, 2, 256)
    flat_rr = rr_history.reshape(B * T, 5)
    flat_lab = label_history.reshape(B * T, 5)
    feats = htf_fused_features(htf, flat_sig, flat_rr, flat_lab)  # (B*T, 576)
    feats = feats.reshape(B, T, -1)                                 # (B, T, 576)
    logits = model(feats, rr_extra, valid_mask)                     # (B, 6)
    pred = logits.argmax(-1)  # 0=NSR, 1=AFIB, 2=SBR, 3=AB, 4=SVTA, 5=B
```
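Steps 1 to 5 leave the preprocessing as placeholders. Below is a hedged NumPy sketch of step 2 for one bout, assuming R-peak indices are already known (step 1), a `(2, n_samples)` signal at 128 Hz, per-lead z-scoring inside each window, RR intervals in seconds, and integer beat labels. The helper name `build_beat_inputs`, the choice to skip beats whose window runs off the bout edge, the reading of "RR history" as the 5 most recent RR intervals ending at the current beat, and the zero/-1 padding are all assumptions; the repo's own preprocessing may differ.

```python
import numpy as np

FS = 128    # sampling rate in Hz (from the model card)
HALF = 128  # +/- 128 samples around each R-peak -> 256-sample (2 s) window
K = 5       # preceding beats used for the RR/label history

def build_beat_inputs(signal, r_peaks, beat_labels):
    """Hypothetical helper: per-beat model inputs for one bout.

    signal:      (2, n_samples) array, two leads at FS Hz
    r_peaks:     sorted R-peak sample indices inside the bout
    beat_labels: int label per R-peak (e.g. 0=N, 1=A, 2=V)
    Returns (windows, rr_history, label_history, rr_extra) shaped
    (T, 2, 256), (T, K), (T, K), (T, 1); T counts beats whose window fits.
    """
    r_peaks = np.asarray(r_peaks, dtype=np.int64)
    beat_labels = np.asarray(beat_labels, dtype=np.int64)
    # RR interval ending at beat i, in seconds; 0.0 for the first beat.
    rr_to_prev = np.concatenate([[0.0], np.diff(r_peaks) / FS])
    wins, rr_hist, lab_hist, rr_extra = [], [], [], []
    for i, r in enumerate(r_peaks):
        if r - HALF < 0 or r + HALF > signal.shape[1]:
            continue  # window would run past the bout edge; skipped here
        w = signal[:, r - HALF:r + HALF].astype(np.float32)
        w = (w - w.mean(axis=1, keepdims=True)) / (w.std(axis=1, keepdims=True) + 1e-6)
        h_rr = np.zeros(K, dtype=np.float32)    # 0.0 where history is missing
        h_lab = np.full(K, -1, dtype=np.int64)  # -1 where no preceding beat exists
        for k in range(K):
            if i - k >= 1:       # k-th most recent RR interval, ending at beat i - k
                h_rr[k] = rr_to_prev[i - k]
            if i - 1 - k >= 0:   # label of the (k + 1)-th preceding beat
                h_lab[k] = beat_labels[i - 1 - k]
        wins.append(w)
        rr_hist.append(h_rr)
        lab_hist.append(h_lab)
        rr_extra.append([rr_to_prev[i]])
    if not wins:
        return (np.zeros((0, 2, 2 * HALF), np.float32), np.zeros((0, K), np.float32),
                np.zeros((0, K), np.int64), np.zeros((0, 1), np.float32))
    return (np.stack(wins), np.stack(rr_hist), np.stack(lab_hist),
            np.asarray(rr_extra, dtype=np.float32))
```

To feed the inference snippet, each array would be converted with `torch.from_numpy` and given a leading batch axis, e.g. `(T, 2, 256)` to `(1, T, 2, 256)`, with `valid_mask` all True for an unpadded single bout.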
For best results, use 7-view TTA at inference; see
[`scripts/eval_with_tta_rfb.py`](https://github.com/), which
re-rolls the within-bout window offset 7 times and averages the softmax probabilities.
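The averaging step of TTA can be sketched independently of the model: each view is the same bout with a shifted window offset, each view's logits are turned into softmax probabilities, and the probabilities (not the hard labels) are averaged before the argmax. The `predict_view` callable and the seven evenly spaced offsets below are hypothetical stand-ins; the actual re-rolling scheme is whatever `scripts/eval_with_tta_rfb.py` implements.

```python
import numpy as np

def softmax(logits):
    z = logits - logits.max(axis=-1, keepdims=True)  # subtract max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def tta_predict(predict_view, offsets):
    """Average softmax probabilities over views.

    predict_view: hypothetical callable mapping a window offset (in samples)
                  to an (n_classes,) logit vector for the same bout.
    Returns (mean_probs, predicted_class_index).
    """
    probs = np.stack([softmax(np.asarray(predict_view(o), dtype=np.float64))
                      for o in offsets])
    mean_probs = probs.mean(axis=0)
    return mean_probs, int(mean_probs.argmax())

# Seven offsets spaced 0.5 s apart at 128 Hz, spanning +/- 1.5 s (an assumption).
OFFSETS = np.linspace(-192, 192, 7).astype(int)
```

Averaging probabilities rather than voting lets a few confident views outweigh several uncertain ones, which matters for rare classes such as SVTA.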
## Architecture
### HTF embedder (`htf_embedder.pt`, 1 143 939 params)
Three parallel streams concatenated into a 576-dim "fused" feature
before a 3-class N/A/V head (we discard the head and use the fused
feature):
- **Time stream**: 1D-CNN trunk on the raw (B, 2, 256) 2-s R-peak window.
5 conv blocks, base width 32 → 256, AdaptiveAvgPool1d → 256-d.
- **Frequency stream**: same trunk on the log-magnitude rFFT (B, 2, 129) →
256-d.
- **History stream**: MLP on (5 RR intervals, 5×3 one-hot of
preceding-beat labels) → 64-d.
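The stream inputs can be sketched in NumPy to make the shapes concrete: a 256-sample window has 256 // 2 + 1 = 129 rFFT bins, the history stream sees 5 RR values plus 5 × 3 one-hot label slots (a 20-d vector), and the stream outputs 256 + 256 + 64 = 576 give the fused width. Using `log1p` for the log-magnitude and an all-zero one-hot row for missing (-1) labels are assumptions, not necessarily what the released embedder does.

```python
import numpy as np

def freq_stream_input(windows):
    """(B, 2, 256) time windows -> (B, 2, 129) log-magnitude spectra."""
    spec = np.fft.rfft(windows, axis=-1)
    return np.log1p(np.abs(spec))  # log1p avoids log(0); an assumption

def history_stream_input(rr_history, label_history, n_classes=3):
    """(B, 5) RR + (B, 5) int labels (-1 = missing) -> (B, 5 + 5 * 3) = (B, 20)."""
    b, k = label_history.shape
    onehot = np.zeros((b, k, n_classes), dtype=np.float32)
    valid = label_history >= 0           # missing labels stay all-zero
    bi, ki = np.nonzero(valid)           # row-major order, same as mask indexing
    onehot[bi, ki, label_history[valid]] = 1.0
    return np.concatenate(
        [rr_history.astype(np.float32), onehot.reshape(b, k * n_classes)], axis=1)

FUSED_DIM = 256 + 256 + 64  # time + frequency + history stream outputs
```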
Trained on **LTAF + CPSC2021 + AFDB** beats (701 561 N + 184 130 A +
169 946 V), AAMI EC57 mapping. Test F1 on LTAF: 0.939.
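For reference, the AAMI EC57 convention groups MIT-BIH-style beat symbols into the superclasses N, S (supraventricular ectopic), V, F and Q; the embedder's N/A/V head presumably corresponds to the N, S and V groups. A hedged sketch of that grouping follows, where renaming S to "A" and dropping the F and Q groups are our assumptions about how the 3-class label set was formed.

```python
# AAMI EC57 beat-class grouping of MIT-BIH annotation symbols.
AAMI_GROUPS = {
    "N": ["N", "L", "R", "e", "j"],  # normal, bundle branch block, escape beats
    "S": ["A", "a", "J", "S"],       # supraventricular ectopic beats
    "V": ["V", "E"],                 # premature ventricular and ventricular escape
    "F": ["F"],                      # fusion of ventricular and normal
    "Q": ["/", "f", "Q"],            # paced, paced fusion, unclassifiable
}

# Hypothetical 3-class N/A/V label map: S becomes "A"; F and Q are dropped.
THREE_CLASS = {"N": 0, "S": 1, "V": 2}
SYMBOL_TO_CLASS = {
    sym: THREE_CLASS[group]
    for group, syms in AAMI_GROUPS.items() if group in THREE_CLASS
    for sym in syms
}

def map_symbol(sym):
    """Return 0 (N), 1 (A), 2 (V), or None for symbols outside those groups."""
    return SYMBOL_TO_CLASS.get(sym)
```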
### Rhythm head (`rhythm_classifier.pt`, 892 806 params)
4-layer Transformer encoder over the sequence of per-beat 577-d vectors
(576-d HTF features + 1 RR-to-previous-beat value):
- **Input projection** Linear(577, 128).
- **Positional embedding** learned, max 64 beats.
- **4 × `_TransformerBlock`**: 4-head MHA + 4×128-d MLP with GELU.
- **Masked-mean pool** over valid beat positions.
- **MLP head** Linear(128, 128) → GELU → Dropout(0.1) → Linear(128, 6).
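A PyTorch sketch of the rhythm head as listed above. Two details are not stated in the list and are assumptions: each block is pre-norm (LayerNorm before attention and before the MLP), and a final LayerNorm sits before pooling. With those choices the module has exactly 892 806 parameters, matching the count in the heading, but the class is an illustration only, not the released `RhythmFromBeats` (whose forward call we only know as `model(feats, rr_extra, valid_mask)`).

```python
import torch
import torch.nn as nn

class Block(nn.Module):
    """Pre-norm Transformer block (assumed): 4-head self-attention + 4x-wide GELU MLP."""
    def __init__(self, d=128, heads=4, dropout=0.1):
        super().__init__()
        self.ln1 = nn.LayerNorm(d)
        self.attn = nn.MultiheadAttention(d, heads, dropout=dropout, batch_first=True)
        self.ln2 = nn.LayerNorm(d)
        self.mlp = nn.Sequential(nn.Linear(d, 4 * d), nn.GELU(), nn.Linear(4 * d, d))
        self.drop = nn.Dropout(dropout)

    def forward(self, x, pad_mask):
        h = self.ln1(x)
        # key_padding_mask: True marks padded beats that attention must ignore
        a, _ = self.attn(h, h, h, key_padding_mask=pad_mask, need_weights=False)
        x = x + self.drop(a)
        return x + self.drop(self.mlp(self.ln2(x)))

class RhythmHeadSketch(nn.Module):
    def __init__(self, d_in=577, d=128, n_layers=4, max_beats=64, n_classes=6):
        super().__init__()
        self.proj = nn.Linear(d_in, d)
        self.pos = nn.Parameter(torch.zeros(1, max_beats, d))  # learned positions
        self.blocks = nn.ModuleList(Block(d) for _ in range(n_layers))
        self.norm = nn.LayerNorm(d)  # assumed final norm before pooling
        self.head = nn.Sequential(nn.Linear(d, d), nn.GELU(), nn.Dropout(0.1),
                                  nn.Linear(d, n_classes))

    def forward(self, feats, rr_extra, valid_mask):
        # feats (B, T, 576), rr_extra (B, T, 1), valid_mask (B, T) bool; T <= max_beats
        x = self.proj(torch.cat([feats, rr_extra], dim=-1))
        x = x + self.pos[:, : x.size(1)]
        for blk in self.blocks:
            x = blk(x, ~valid_mask)
        x = self.norm(x)
        m = valid_mask.unsqueeze(-1).to(x.dtype)
        pooled = (x * m).sum(1) / m.sum(1).clamp(min=1.0)  # masked mean over valid beats
        return self.head(pooled)
```

Because padded beats are excluded both as attention keys and from the pooled mean, the logits of a bout do not depend on what is stored in its padding slots.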
Pretrained for 8 epochs as a binary {NSR, AFIB} classifier on 4 corpora
(LTAF train + CPSC2021 + AFDB + 200 Icentia11k records), then fine-tuned
for 30 epochs on the 6 LTAF classes (train and val merged, val-stratified
checkpoint selection on 8 held-back records).
## Training
Reproducer in [rmxjck/TSLM-Arena](https://github.com/) (commit
forthcoming):
```bash
# Stage 0: train multi-corpus HTF beats classifier
.venv/bin/python scripts/train_ecg_beat_htf_multicorpus.py \
--corpora ltaf cpsc2021 afdb \
--epochs 12 --batch-size 256 --lr 1e-3 \
--output-dir results/ecg_classifier/beats_htf_multicorpus
# Stage 1: binary {NSR, AFIB} pretrain on 4 corpora
.venv/bin/python scripts/train_rhythm_from_beats_multicorpus.py \
--htf-checkpoint results/ecg_classifier/beats_htf_multicorpus/best_classifier.pt \
--corpora ltaf cpsc2021 afdb icentia \
--icentia-records 200 --max-bouts-per-record 100 \
--epochs 8 --batch-size 32 --lr 5e-4 \
--classes NSR AFIB \
--output-dir results/ecg_classifier/sweep/c2_rhythm_from_beats_pretrain
# Stage 2: LTAF 6-class fine-tune
.venv/bin/python scripts/train_rhythm_from_beats_finetune.py \
--htf-checkpoint results/ecg_classifier/beats_htf_multicorpus/best_classifier.pt \
--pretrained results/ecg_classifier/sweep/c2_rhythm_from_beats_pretrain/best_classifier.pt \
--window-seconds 10 --max-beats 32 \
--epochs 30 --batch-size 32 --lr 5e-4 \
--use-val-as-train \
--classes NSR AFIB SBR AB SVTA B \
--output-dir results/ecg_classifier/sweep/c6_rhythm_from_beats_v2_ft
# Eval with TTA-7
.venv/bin/python scripts/eval_with_tta_rfb.py \
--htf-checkpoint results/ecg_classifier/beats_htf_multicorpus/best_classifier.pt \
--checkpoint results/ecg_classifier/sweep/c6_rhythm_from_beats_v2_ft/best_classifier.pt \
--n-views 7 \
--output results/ecg_classifier/sweep/c6_rhythm_from_beats_v2_ft_tta.json
```
Best val macro F1 0.741 at epoch 23. Test single-view 0.731, TTA-7
0.767. Total training time on a single H100: ~25 min for HTF + ~1 min
for rhythm pretrain + ~20 min for rhythm fine-tune + ~30 s for TTA eval.
## Ablations
| Variant | Test macro F1 |
|---|---:|
| **v2 + TTA-7** | **0.767** |
| v2 single-window | 0.731 |
| v2 + v4 ensemble + TTA-7 | 0.756 |
| v1 (RhythmResNet1D + TTA-16, no beat embedding) | 0.658 |
| v4 (fp32 cache, 6-corpus pretrain w/ mitdb+nsrdb+Icentia 2k) + TTA-7 | 0.711 |
| v3 (fp16 cache regression) + ft | 0.684 |
| v1 + v2 ensemble + TTA-7 | 0.710 |
Notable negative results (didn't help):
- **fp16 feature cache**: 5 pp regression vs live HTF (precision loss
in the 576-d embeddings hurt the rhythm head's discriminative
power). v2 therefore runs HTF live (uncached). The fp32 cache used in v4
avoids the fp16 rounding and speeds up training, but v4's bigger pretrain
corpus skewed toward NSR, and v4 still trails v2 (0.711 vs 0.767).
- **Multi-model ensembles** (v1 + v2: 0.710; v2 + v4: 0.756): regressions
vs v2 alone, because v1 and v4 are weaker models.
- **Multi-corpus MAE pretrain on raw signal** (v3 transformer): F1
0.34; beat-level supervision matters more than raw-signal
reconstruction at this corpus size.
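For intuition on the fp16 result: float16 keeps an 11-bit significand, so casting a value in the normal range rounds it with relative error at most 2**-11 (about 4.9e-4), and magnitudes above 65 504 overflow to infinity. The sketch measures the round-trip error on a random 576-d vector standing in for one HTF embedding (the random data is an assumption); whether this rounding fully explains the 5 pp drop is the authors' interpretation, not something this sketch shows.

```python
import numpy as np

rng = np.random.default_rng(0)
emb = rng.standard_normal(576).astype(np.float32)  # stand-in for one HTF embedding

roundtrip = emb.astype(np.float16).astype(np.float32)  # what an fp16 cache stores
normal = np.abs(emb) >= np.finfo(np.float16).tiny      # skip fp16 subnormals
rel_err = np.abs(roundtrip[normal] - emb[normal]) / np.abs(emb[normal])

print(f"max relative error: {rel_err.max():.2e}")  # bounded by 2**-11
print(f"max absolute error: {np.abs(roundtrip - emb).max():.2e}")
```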
## Not for clinical use
Research artifact only. Not FDA-cleared. Not suitable for triage,
diagnosis, or any patient-facing application.
## Citations
```bibtex
@misc{petrutiu2008ltafdb,
title = {Abrupt Changes in Fibrillatory Wave Characteristics at the Termination of Paroxysmal Atrial Fibrillation in Humans},
author = {Petrutiu, Simona and Sahakian, Alan V. and Swiryn, Steven},
year = {2008},
howpublished = {PhysioNet},
url = {https://physionet.org/content/ltafdb/}
}
```
CPSC 2021, MIT-BIH AFDB and Icentia11k were also used for the
multi-corpus stages.