---
library_name: pytorch
tags:
- ecg
- classification
- arrhythmia
- ltaf
- physionet
- rhythm
license: mit
---
# LTAF ECG Rhythm Classifier: RhythmFromBeats v2
Beat-embedding rhythm classifier for the PhysioNet Long-Term AF Database (LTAF). Two-stage pipeline:
1. **Frozen HTF beat embedder** (`htf_embedder.pt`, 1.14 M params) trained
on **3 corpora** (LTAF + CPSC2021 + AFDB, 1 056 K beats; N 0.977,
A 0.935, V 0.907 on LTAF held-out test).
2. **Beat-sequence Transformer head** (`rhythm_classifier.pt`, 893 K params)
trained on the per-beat (576-d) features extracted by the embedder
for each rhythm bout. Pretrained binary {NSR, AFIB} on 4 corpora
(LTAF + CPSC2021 + AFDB + Icentia 200), then fine-tuned to LTAF
6-class (NSR / AFIB / SBR / AB / SVTA / B).
7-view test-time augmentation (TTA-7) reaches **macro F1 = 0.767** on the
LTAF held-out 9-record test set: **+11 pp over the previous v1
RhythmResNet1D + TTA** (0.658) and **+47 pp over the Chronos-2 frozen
baseline** (0.299).
## Why beat embeddings?
The previous v1 (a from-scratch 1D-ResNet on raw 10-s windows) was
limited by patient-distribution shift in LTAF. Beat morphology
generalizes well across patients (HTF F1 = 0.94), but rhythm
classification on raw signal struggled to factor that out.

The v2 pipeline decouples the two: the embedder extracts patient-robust
beat features, and the rhythm head only learns *sequence patterns*
(irregular RR for AFIB, N-V-N-V for B, fast regular for SVTA, etc.) on
top. This collapses the patient-shift bottleneck. **SVTA F1 went from
0.15 → 0.71 (+57 pp on the worst class)** between v1 and v2 + TTA-7.
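
To make "sequence patterns" concrete, the sketch below derives RR intervals from R-peak sample indices and summarizes them with two simple statistics. It is an illustration only: the rhythm head learns its own features from the beat sequence, and the helper name `rr_summary` and the example peak positions are invented for this sketch. The 128 Hz sampling rate is the one stated in the inference notes.

```python
import numpy as np

FS = 128  # sampling rate in Hz, as stated in the inference notes


def rr_summary(r_peaks, fs=FS):
    """Return (mean heart rate in bpm, coefficient of variation of RR)."""
    rr = np.diff(np.asarray(r_peaks, dtype=float)) / fs  # RR intervals in seconds
    mean_rr = rr.mean()
    return 60.0 / mean_rr, rr.std() / mean_rr


# Hypothetical R-peak positions (sample indices), not taken from LTAF.
regular = np.arange(0, 1280, 96)  # one beat every 0.75 s
irregular = np.cumsum([0, 70, 130, 60, 150, 85, 140, 65, 120, 90])

hr_reg, cv_reg = rr_summary(regular)      # steady rate, near-zero RR variation
hr_irr, cv_irr = rr_summary(irregular)    # similar rate range, large RR variation
```

A regular rhythm gives a coefficient of variation near zero, while an irregular one gives a clearly larger value; a fast regular rhythm would instead show up as a high rate with low variation.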
## Classes
| Code | Expansion |
|------|-----------|
| NSR | Normal sinus rhythm |
| AFIB | Atrial fibrillation |
| SBR | Sinus bradycardia (<60 bpm, sinus origin) |
| AB | Atrial bigeminy (every other beat is an APC) |
| SVTA | Supraventricular tachyarrhythmia (≥3 consecutive SV ectopics @ >100 bpm) |
| B | Ventricular bigeminy (every other beat is a PVC) |

`VT` (31 test windows), `T` (26), `IVR` (1) were dropped: their supports
are too small for stable F1 estimation.
## Test results: LTAF held-out (9 records, 3 716 windows)
### Single-window (no TTA)
| Metric | Value |
|---|---:|
| Accuracy | 0.713 |
| Balanced accuracy | 0.805 |
| Macro F1 | **0.731** |
### TTA-7 (recommended)
| Metric | Value |
|---|---:|
| Accuracy | 0.760 |
| Balanced accuracy | 0.894 |
| **Macro F1** | **0.767** |

Per-class F1 with TTA-7:

| Class | F1 | v1+TTA-7 | Δ (pp) |
|---|---:|---:|---:|
| NSR | 0.754 | 0.756 | -0.2 |
| AFIB | 0.759 | 0.619 | **+14.0** |
| SBR | 0.759 | 0.821 | -6.2 |
| AB | 0.782 | 0.769 | +1.3 |
| **SVTA** | **0.715** | 0.148 | **+56.7** |
| B | 0.835 | 0.821 | +1.4 |
| Macro | **0.767** | 0.658 | **+10.9** |
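
As a quick check of how the table is built: macro F1 is the unweighted mean of the per-class F1 scores, and Δ is the difference to v1 in percentage points. The values below are copied from the table above; the variable names are only for this sketch.

```python
# Per-class F1 with TTA-7, copied from the table above.
v2_f1 = {"NSR": 0.754, "AFIB": 0.759, "SBR": 0.759, "AB": 0.782, "SVTA": 0.715, "B": 0.835}
v1_f1 = {"NSR": 0.756, "AFIB": 0.619, "SBR": 0.821, "AB": 0.769, "SVTA": 0.148, "B": 0.821}

# Macro F1: every class counts equally, regardless of its support.
macro_f1 = sum(v2_f1.values()) / len(v2_f1)
print(round(macro_f1, 3))  # 0.767

# Per-class change from v1 to v2, in percentage points.
delta_pp = {c: round((v2_f1[c] - v1_f1[c]) * 100, 1) for c in v2_f1}
```

Because each class is weighted equally, the SVTA gain moves the macro score far more than its window count alone would suggest.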
## Inference
```python
import torch
from src.models.ts_llm.ecg_beat_htf import EcgBeatHTFClassifier
from src.models.ts_llm.rhythm_from_beats import RhythmFromBeats, htf_fused_features
device = "cuda" if torch.cuda.is_available() else "cpu"
htf = EcgBeatHTFClassifier.load("htf_embedder.pt", device=device).eval()
model = RhythmFromBeats.load("rhythm_classifier.pt", device=device).eval()
# For each rhythm bout (10 s @ 128 Hz, 2 leads):
# 1. find R-peak indices in the bout
# 2. for each R-peak, extract a 2 s window (+/- 128 samples) and
# build (rr_history, label_history) from preceding K=5 beats
# 3. run HTF on the per-beat batch -> (n_beats, 576) features
# 4. concat (rr-to-prev, normalized seconds) -> (n_beats, 577)
# 5. pass to rhythm head with valid_mask
beat_signals = ... # (B, T, 2, 256) per-beat windows, z-scored
rr_history = ... # (B, T, 5) RR intervals to preceding 5 beats
label_history = ... # (B, T, 5) preceding 5 beat labels (-1 if N/A)
rr_extra = ... # (B, T, 1) RR-to-prev in this bout
valid_mask = ... # (B, T) True at valid beat positions
with torch.no_grad():
    flat_sig = beat_signals.view(-1, 2, 256)
    flat_rr = rr_history.view(-1, 5)
    flat_lab = label_history.view(-1, 5)
    feats = htf_fused_features(htf, flat_sig, flat_rr, flat_lab)  # (B*T, 576)
    feats = feats.view(beat_signals.size(0), beat_signals.size(1), -1)
    logits = model(feats, rr_extra, valid_mask)  # (B, 6)
    pred = logits.argmax(-1)  # 0=NSR, 1=AFIB, 2=SBR, 3=AB, 4=SVTA, 5=B
```
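
The placeholder tensors above can be built roughly as follows for a single bout. This is a minimal NumPy sketch under stated assumptions, not the repository's preprocessing: the helper name `beat_windows`, per-lead z-scoring, zero padding, "most recent first" ordering of the RR history, and dropping beats whose window does not fit inside the bout are all choices made here for illustration. `max_beats=32` follows the `--max-beats 32` flag in the training commands. Beat labels are not available from the signal alone, so `label_history` is filled with -1 (N/A).

```python
import numpy as np

FS, HALF, K = 128, 128, 5  # 128 Hz, +/-128 samples (2 s window), 5-beat history


def beat_windows(signal, r_peaks, max_beats=32):
    """Build padded per-beat inputs for one bout.

    signal: (2, n_samples) array; r_peaks: sorted R-peak sample indices.
    Returns beat_signals (T, 2, 256), rr_history (T, K), label_history (T, K),
    rr_extra (T, 1) and valid_mask (T,), with T = max_beats.
    """
    beats = np.zeros((max_beats, 2, 2 * HALF), dtype=np.float32)
    rr_hist = np.zeros((max_beats, K), dtype=np.float32)
    lab_hist = np.full((max_beats, K), -1, dtype=np.int64)  # -1 = label not available
    rr_extra = np.zeros((max_beats, 1), dtype=np.float32)
    valid = np.zeros(max_beats, dtype=bool)

    # Keep only peaks whose full 2 s window fits inside the bout (assumption).
    n = signal.shape[1]
    peaks = np.asarray([p for p in r_peaks if p - HALF >= 0 and p + HALF <= n],
                       dtype=np.int64)[:max_beats]

    # RR-to-previous in seconds; the first beat has no predecessor, so 0.
    rr = np.zeros(len(peaks), dtype=np.float32)
    rr[1:] = np.diff(peaks) / FS

    for t, p in enumerate(peaks):
        w = signal[:, p - HALF:p + HALF].astype(np.float32)
        # Per-lead z-score (assumed normalization).
        w = (w - w.mean(axis=1, keepdims=True)) / (w.std(axis=1, keepdims=True) + 1e-6)
        beats[t] = w
        hist = rr[max(0, t - K + 1):t + 1][::-1]  # most recent interval first
        rr_hist[t, :len(hist)] = hist
        rr_extra[t, 0] = rr[t]
        valid[t] = True
    return beats, rr_hist, lab_hist, rr_extra, valid
```

Stacking the outputs of several bouts along a new leading axis and converting with `torch.from_numpy` gives the `(B, T, ...)` tensors used in the snippet above. The padded positions stay zero and are marked `False` in `valid_mask`.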
For best results, use 7-view TTA at inference; see
[`scripts/eval_with_tta_rfb.py`](https://github.com/), which
re-rolls the within-bout window offset 7 times and averages the softmax outputs.
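
The averaging step can be sketched as below. This is not the code in `eval_with_tta_rfb.py`: `predict_logits` is a hypothetical callable standing in for "cut the window at this offset, run HTF and the rhythm head", and uniform random offsets are an assumption about how the offset is re-rolled.

```python
import numpy as np


def softmax(logits):
    z = logits - logits.max(axis=-1, keepdims=True)  # subtract max for stability
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)


def tta_predict(predict_logits, bout_len, window_len, n_views=7, seed=0):
    """Average class probabilities over n_views random window offsets.

    predict_logits(offset) -> (n_classes,) logits (hypothetical callable).
    Returns (predicted class index, mean probability vector).
    """
    rng = np.random.default_rng(seed)
    max_offset = max(bout_len - window_len, 0)
    offsets = rng.integers(0, max_offset + 1, size=n_views)
    probs = np.stack([
        softmax(np.asarray(predict_logits(int(o)), dtype=float)) for o in offsets
    ])
    mean_probs = probs.mean(axis=0)  # average softmax, not logits
    return int(mean_probs.argmax()), mean_probs
```

Averaging probabilities rather than taking a majority vote keeps the confidence of each view in play, so one confident view can outweigh several uncertain ones.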
## Architecture
### HTF embedder (`htf_embedder.pt`, 1 143 939 params)
Three parallel streams concatenated into a 576-dim "fused" feature
before a 3-class N/A/V head (we discard the head and use the fused
feature):
- **Time stream**: 1D-CNN trunk on raw (B, 2, 256) 2-s R-peak window.
  5 conv blocks, base width 32 → 256, AdaptiveAvgPool1d → 256-d.
- **Frequency stream**: same trunk on log-magnitude rFFT (B, 2, 129) →
  256-d.
- **History stream**: MLP on (5 RR intervals, 5×3 one-hot of
  preceding-beat labels) → 64-d.

Trained on **LTAF + CPSC2021 + AFDB** beats (701 561 N + 184 130 A +
169 946 V), AAMI EC57 mapping. Test F1 on LTAF: 0.939.
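
The input shapes of the frequency and history streams follow directly from the description: a real FFT of 256 samples has 256 / 2 + 1 = 129 bins, and 5 RR intervals plus 5×3 one-hot labels give a 20-d vector. The sketch below shows one way to build both inputs. The function names, the use of `log1p` to avoid log(0), and encoding missing labels (-1) as all-zero rows are assumptions, not confirmed details of the HTF code.

```python
import numpy as np


def freq_stream_input(beats):
    """beats: (B, 2, 256) -> log-magnitude rFFT of shape (B, 2, 129)."""
    spec = np.fft.rfft(beats, axis=-1)  # 256 real samples -> 129 frequency bins
    return np.log1p(np.abs(spec))       # log1p is an assumption; avoids log(0)


def history_stream_input(rr_history, label_history, n_classes=3):
    """rr_history: (B, 5) floats; label_history: (B, 5) ints in {0, 1, 2} or -1.

    Returns (B, 5 + 5 * 3) = (B, 20): RR intervals followed by flattened one-hots.
    """
    labels = np.asarray(label_history)
    onehot = np.zeros(labels.shape + (n_classes,), dtype=np.float32)
    b, k = np.nonzero(labels >= 0)      # positions with a known label
    onehot[b, k, labels[b, k]] = 1.0    # unknown labels stay all-zero (assumed)
    rr = np.asarray(rr_history, dtype=np.float32)
    return np.concatenate([rr, onehot.reshape(len(labels), -1)], axis=1)
```

The three stream outputs then concatenate to 256 + 256 + 64 = 576 dimensions, which is the fused feature passed to the rhythm head.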
### Rhythm head (`rhythm_classifier.pt`, 892 806 params)
4-layer Transformer encoder over the per-beat (576+1 = 577)-d
sequence:
- **Input projection** Linear(577, 128).
- **Positional embedding** learned, max 64 beats.
- **4 × `_TransformerBlock`**: 4-head MHA + 4×128-d MLP with GELU.
- **Masked-mean pool** over valid beat positions.
- **MLP head** Linear(128, 128) → GELU → Dropout(0.1) → Linear(128, 6).

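
The stated parameter count can be reproduced with plain arithmetic from the list above. Two details are assumptions that the card does not state: each block has two LayerNorms, and one final LayerNorm precedes pooling. With those assumptions and biases on every linear layer, the total matches 892 806. The number of attention heads does not change the count.

```python
def linear(n_in, n_out):
    return n_in * n_out + n_out  # weights + bias


def layer_norm(d):
    return 2 * d  # scale + shift


D, D_IN, MAX_BEATS, N_LAYERS, N_CLASSES = 128, 577, 64, 4, 6

input_proj = linear(D_IN, D)                # Linear(577, 128)
pos_emb = MAX_BEATS * D                     # learned positions, max 64 beats
attn = 4 * linear(D, D)                     # Q, K, V and output projections
mlp = linear(D, 4 * D) + linear(4 * D, D)   # 128 -> 512 -> 128
block = attn + mlp + 2 * layer_norm(D)      # two LayerNorms per block (assumed)
final_norm = layer_norm(D)                  # assumed; not listed in the card
head = linear(D, D) + linear(D, N_CLASSES)  # Linear(128, 128) -> Linear(128, 6)

total = input_proj + pos_emb + N_LAYERS * block + final_norm + head
print(total)  # 892806
```

About 89% of the parameters sit in the four Transformer blocks; the 577-to-128 input projection accounts for most of the rest.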
Pretrained binary {NSR, AFIB} on 4 corpora (LTAF train + CPSC2021 +
AFDB + Icentia 200, 8 epochs), then fine-tuned 30 epochs on LTAF
6-class (train+val merged, val-stratified checkpointing on 8 held-back
records).
## Training
Reproducer in [rmxjck/TSLM-Arena](https://github.com/) (commit
forthcoming):
```bash
# Stage 0: train multi-corpus HTF beats classifier
.venv/bin/python scripts/train_ecg_beat_htf_multicorpus.py \
--corpora ltaf cpsc2021 afdb \
--epochs 12 --batch-size 256 --lr 1e-3 \
--output-dir results/ecg_classifier/beats_htf_multicorpus
# Stage 1: binary {NSR, AFIB} pretrain on 4 corpora
.venv/bin/python scripts/train_rhythm_from_beats_multicorpus.py \
--htf-checkpoint results/ecg_classifier/beats_htf_multicorpus/best_classifier.pt \
--corpora ltaf cpsc2021 afdb icentia \
--icentia-records 200 --max-bouts-per-record 100 \
--epochs 8 --batch-size 32 --lr 5e-4 \
--classes NSR AFIB \
--output-dir results/ecg_classifier/sweep/c2_rhythm_from_beats_pretrain
# Stage 2: LTAF 6-class fine-tune
.venv/bin/python scripts/train_rhythm_from_beats_finetune.py \
--htf-checkpoint results/ecg_classifier/beats_htf_multicorpus/best_classifier.pt \
--pretrained results/ecg_classifier/sweep/c2_rhythm_from_beats_pretrain/best_classifier.pt \
--window-seconds 10 --max-beats 32 \
--epochs 30 --batch-size 32 --lr 5e-4 \
--use-val-as-train \
--classes NSR AFIB SBR AB SVTA B \
--output-dir results/ecg_classifier/sweep/c6_rhythm_from_beats_v2_ft
# Eval with TTA-7
.venv/bin/python scripts/eval_with_tta_rfb.py \
--htf-checkpoint results/ecg_classifier/beats_htf_multicorpus/best_classifier.pt \
--checkpoint results/ecg_classifier/sweep/c6_rhythm_from_beats_v2_ft/best_classifier.pt \
--n-views 7 \
--output results/ecg_classifier/sweep/c6_rhythm_from_beats_v2_ft_tta.json
```
Best val macro F1 0.741 at epoch 23. Test single-view 0.731, TTA-7
0.767. Total training time on a single H100: ~25 min for HTF + ~1 min
for rhythm pretrain + ~20 min for rhythm fine-tune + ~30 s for TTA eval.
## Ablations
| Variant | Test macro F1 |
|---|---:|
| **v2 + TTA-7** | **0.767** |
| v2 single-window | 0.731 |
| v2 + v4 ensemble + TTA-7 | 0.756 |
| v1 (RhythmResNet1D + TTA-16, no beat embedding) | 0.658 |
| v4 (fp32 cache, 6-corpus pretrain w/ mitdb+nsrdb+Icentia 2k) + TTA-7 | 0.711 |
| v3 (fp16 cache regression) + ft | 0.684 |
| v1 + v2 ensemble + TTA-7 | 0.710 |
Notable negative results (didn't help):
- **fp16 feature cache**: 5 pp regression vs live HTF (precision loss
in the 576-d embeddings hurt the rhythm head's discriminative
power). v2 uses live HTF (uncached); fp32 cache (v4) helps with
speed but the bigger pretrain corpus skewed toward NSR.
- **Multi-model ensembles** (v1+v2, v2+v4): mild regressions because
v1 and v4 are weaker baselines.
- **Multi-corpus MAE pretrain on raw signal** (v3 transformer): F1
  0.34. Beat-level supervision matters more than raw-signal
  reconstruction at this corpus size.
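
For anyone who wants to check how much a cache dtype perturbs features before committing to it, the round-trip error is easy to measure. The sketch below uses synthetic Gaussian values as a stand-in for real HTF embeddings, so the magnitudes it prints say nothing about the 5 pp regression reported above; it only shows the measurement.

```python
import numpy as np

rng = np.random.default_rng(0)
# Synthetic stand-in for cached HTF features (real embeddings will differ).
feats = rng.standard_normal((1000, 576)).astype(np.float32)

cached = feats.astype(np.float16)     # what an fp16 cache would store
restored = cached.astype(np.float32)  # what the rhythm head would read back

abs_err = np.abs(restored - feats)
print(f"max abs error:  {abs_err.max():.2e}")
print(f"mean abs error: {abs_err.mean():.2e}")
```

Running the same comparison on actual embeddings, and on the rhythm head's logits computed from both versions, would show whether the cache changes predictions near decision boundaries.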
## Not for clinical use
Research artifact only. Not FDA-cleared. Not suitable for triage,
diagnosis, or any patient-facing application.
## Citations
```bibtex
@misc{petrutiu2008ltafdb,
title = {Abrupt Changes in Fibrillatory Wave Characteristics at the Termination of Paroxysmal Atrial Fibrillation in Humans},
author = {Petrutiu, Simona and Sahakian, Alan V. and Swiryn, Steven},
year = {2008},
howpublished = {PhysioNet},
url = {https://physionet.org/content/ltafdb/}
}
```
CPSC 2021, MIT-BIH AFDB and Icentia11k were also used for the
multi-corpus stages.