---
library_name: pytorch
tags:
  - ecg
  - classification
  - arrhythmia
  - ltaf
  - physionet
  - rhythm
license: mit
---

# LTAF ECG Rhythm Classifier: RhythmFromBeats v2

Beat-embedding rhythm classifier for LTAF. Two-stage pipeline:

1. **Frozen HTF beat embedder** (`htf_embedder.pt`, 1.14 M params) trained
   on **3 corpora** (LTAF + CPSC2021 + AFDB, 1 056 K beats; N 0.977,
   A 0.935, V 0.907 on LTAF held-out test).
2. **Beat-sequence Transformer head** (`rhythm_classifier.pt`, 893 K params)
   trained on the per-beat (576-d) features extracted by the embedder
   for each rhythm bout. Pretrained binary {NSR, AFIB} on 4 corpora
   (LTAF + CPSC2021 + AFDB + Icentia 200), then fine-tuned to LTAF
   6-class (NSR / AFIB / SBR / AB / SVTA / B).

7-view test-time augmentation (TTA-7) reaches **macro F1 = 0.767** on the
LTAF held-out 9-record test set: **+11 pp over the previous v1
RhythmResNet1D + TTA** (0.658) and **+47 pp over the Chronos-2 frozen
baseline** (0.299).

## Why beat embeddings?

The previous v1 (a from-scratch 1D-ResNet on raw 10-s windows) was
limited by patient-distribution shift in LTAF: beat morphology
generalizes well across patients (HTF F1 = 0.94), but rhythm
classification on raw signal struggled to factor that out.

The v2 pipeline decouples it: the embedder extracts patient-robust
beat features, and the rhythm head only learns *sequence patterns*
(irregular RR for AFIB, N-V-N-V for B, fast regular for SVTA, etc.) on
top. This collapses the patient-shift bottleneck. **SVTA F1 went from
0.148 → 0.715 (+56.7 pp on the worst class)** between v1 and v2 + TTA-7.

## Classes

| Code | Expansion |
|------|-----------|
| NSR  | Normal sinus rhythm |
| AFIB | Atrial fibrillation |
| SBR  | Sinus bradycardia (<60 bpm, sinus origin) |
| AB   | Atrial bigeminy (every other beat is an APC) |
| SVTA | Supraventricular tachyarrhythmia (≥3 consecutive SV ectopics at >100 bpm) |
| B    | Ventricular bigeminy (every other beat is a PVC) |

`VT` (31 test windows), `T` (26), and `IVR` (1) were dropped: their
supports are too small for stable F1 estimation.

## Test results β€” LTAF held-out (9 records, 3 716 windows)

### Single-window (no TTA)

| Metric | Value |
|---|---:|
| Accuracy | 0.713 |
| Balanced accuracy | 0.805 |
| Macro F1 | **0.731** |

### TTA-7 (recommended)

| Metric | Value |
|---|---:|
| Accuracy | 0.760 |
| Balanced accuracy | 0.894 |
| **Macro F1** | **0.767** |

Per-class F1 with TTA-7:

| Class | v2 F1 | v1 + TTA-7 F1 | Δ (pp) |
|---|---:|---:|---:|
| NSR  | 0.754 | 0.756 | -0.2 |
| AFIB | 0.759 | 0.619 | **+14.0** |
| SBR  | 0.759 | 0.821 | -6.2 |
| AB   | 0.782 | 0.769 | +1.3 |
| **SVTA** | **0.715** | 0.148 | **+56.7** |
| B    | 0.835 | 0.821 | +1.4 |
| Macro | **0.767** | 0.658 | **+10.9** |
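
As a quick consistency check, macro F1 is the unweighted mean of the six per-class F1 scores, and Δ is the v2 minus v1 difference in percentage points. The snippet below recomputes both from the rounded values in the table; the variable names are illustrative only.

```python
# Per-class F1 with TTA-7, copied from the table above (rounded to 3 decimals).
f1_v2 = {"NSR": 0.754, "AFIB": 0.759, "SBR": 0.759,
         "AB": 0.782, "SVTA": 0.715, "B": 0.835}

# Macro F1 is the plain (unweighted) mean over classes.
macro_f1 = sum(f1_v2.values()) / len(f1_v2)
print(f"macro F1 = {macro_f1:.3f}")  # macro F1 = 0.767

# Delta for the worst v1 class, in percentage points.
svta_delta_pp = round((f1_v2["SVTA"] - 0.148) * 100, 1)
print(f"SVTA delta = +{svta_delta_pp} pp")  # SVTA delta = +56.7 pp
```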

## Inference

```python
import torch
from src.models.ts_llm.ecg_beat_htf import EcgBeatHTFClassifier
from src.models.ts_llm.rhythm_from_beats import RhythmFromBeats, htf_fused_features

device = "cuda" if torch.cuda.is_available() else "cpu"
htf = EcgBeatHTFClassifier.load("htf_embedder.pt", device=device).eval()
model = RhythmFromBeats.load("rhythm_classifier.pt", device=device).eval()

# For each rhythm bout (10 s @ 128 Hz, 2 leads):
#   1. find R-peak indices in the bout
#   2. for each R-peak, extract a 2 s window (+/- 128 samples) and
#      build (rr_history, label_history) from preceding K=5 beats
#   3. run HTF on the per-beat batch -> (n_beats, 576) features
#   4. concat (rr-to-prev, normalized seconds) -> (n_beats, 577)
#   5. pass to rhythm head with valid_mask
beat_signals  = ...  # (B, T, 2, 256) per-beat windows, z-scored
rr_history    = ...  # (B, T, 5)      RR intervals to preceding 5 beats
label_history = ...  # (B, T, 5)      preceding 5 beat labels (-1 if N/A)
rr_extra      = ...  # (B, T, 1)      RR-to-prev in this bout
valid_mask    = ...  # (B, T)         True at valid beat positions

with torch.no_grad():
    # reshape (rather than view) also works when the inputs are non-contiguous
    flat_sig = beat_signals.reshape(-1, 2, 256)
    flat_rr  = rr_history.reshape(-1, 5)
    flat_lab = label_history.reshape(-1, 5)
    feats    = htf_fused_features(htf, flat_sig, flat_rr, flat_lab)  # (B*T, 576)
    feats    = feats.view(beat_signals.size(0), beat_signals.size(1), -1)
    logits   = model(feats, rr_extra, valid_mask)  # (B, 6)
    pred     = logits.argmax(-1)  # 0=NSR, 1=AFIB, 2=SBR, 3=AB, 4=SVTA, 5=B
```
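
Steps 1 and 2 in the comments above can be sketched in NumPy as follows. This is a minimal illustration, not the repository's preprocessing: `extract_beats` is a hypothetical helper, R-peak detection is assumed to have been done already, and the per-lead z-scoring, the zero-padding of missing RR history, and the dropping of beats near the bout edges are all assumptions. Building `label_history` needs per-beat labels and is not shown.

```python
import numpy as np

FS = 128    # sampling rate in Hz (from the model card)
HALF = 128  # +/- 128 samples around each R-peak -> 256-sample (2 s) window
K = 5       # number of RR intervals kept as history


def extract_beats(bout, r_peaks, half=HALF, k=K, fs=FS):
    """Cut per-beat windows and RR features from one bout.

    bout:    (2, n_samples) float array holding two ECG leads.
    r_peaks: sorted 1-D array of R-peak sample indices.
    Returns windows (n, 2, 2*half), RR history (n, k), RR-to-prev (n, 1);
    RR values are in seconds, 0.0 where no earlier beat exists.
    """
    r_peaks = np.asarray(r_peaks, dtype=np.int64)
    rr_prev = np.zeros(len(r_peaks), dtype=np.float32)
    rr_prev[1:] = np.diff(r_peaks) / fs

    windows, rr_hist, rr_extra = [], [], []
    for i, r in enumerate(r_peaks):
        if r - half < 0 or r + half > bout.shape[1]:
            continue  # assumption: drop beats too close to the bout edges
        w = bout[:, r - half:r + half].astype(np.float32)
        # Assumed normalisation: z-score each lead within its own window.
        w = (w - w.mean(axis=1, keepdims=True)) / (w.std(axis=1, keepdims=True) + 1e-6)
        past = rr_prev[max(0, i - k + 1):i + 1]
        hist = np.zeros(k, dtype=np.float32)
        hist[k - len(past):] = past  # left-pad with zeros
        windows.append(w)
        rr_hist.append(hist)
        rr_extra.append([rr_prev[i]])

    if not windows:  # no usable beat in this bout
        return (np.zeros((0, 2, 2 * half), np.float32),
                np.zeros((0, k), np.float32),
                np.zeros((0, 1), np.float32))
    return np.stack(windows), np.stack(rr_hist), np.asarray(rr_extra, dtype=np.float32)


# Synthetic 10 s, 2-lead bout with a peak every 100 samples.
rng = np.random.default_rng(0)
bout = rng.standard_normal((2, 10 * FS))
r_peaks = np.arange(100, 10 * FS, 100)
beats, rr_hist, rr_extra = extract_beats(bout, r_peaks)
print(beats.shape, rr_hist.shape, rr_extra.shape)  # (10, 2, 256) (10, 5) (10, 1)
```

The outputs still need to be padded to a common beat count `T` and batched, together with a `valid_mask`, before they match the `(B, T, ...)` shapes used in the inference snippet.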

For best results, use 7-view TTA at inference; see
[`scripts/eval_with_tta_rfb.py`](https://github.com/), which
re-rolls the within-bout window offset 7 times and averages the softmax outputs.
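
The averaging step itself is small. The sketch below shows softmax averaging over views under simplifying assumptions: `tta_predict` is a hypothetical helper rather than the script's API, and a dummy linear layer stands in for the full HTF + rhythm-head pipeline. How the 7 window offsets are drawn is left to the script.

```python
import torch


def tta_predict(predict_logits, views):
    """Average softmax probabilities over several views of the same bouts.

    predict_logits: callable mapping one view to logits of shape (B, n_classes).
    views:          iterable of model inputs, one per re-rolled window offset.
    """
    probs = [torch.softmax(predict_logits(v), dim=-1) for v in views]
    mean_probs = torch.stack(probs, dim=0).mean(dim=0)  # (B, n_classes)
    return mean_probs.argmax(dim=-1), mean_probs


torch.manual_seed(0)
dummy = torch.nn.Linear(8, 6)                  # stand-in for the real pipeline
views = [torch.randn(4, 8) for _ in range(7)]  # 7 synthetic views of 4 bouts
with torch.no_grad():
    pred, probs = tta_predict(dummy, views)
print(pred.shape, probs.shape)  # torch.Size([4]) torch.Size([4, 6])
```

Averaging probabilities rather than logits keeps each view's contribution bounded, so a single over-confident view cannot dominate the vote.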

## Architecture

### HTF embedder (`htf_embedder.pt`, 1 143 939 params)

Three parallel streams concatenated into a 576-dim "fused" feature
before a 3-class N/A/V head (we discard the head and use the fused
feature):

- **Time stream**: 1D-CNN trunk on the raw (B, 2, 256) 2-s R-peak window.
  5 conv blocks, base width 32 → 256, AdaptiveAvgPool1d → 256-d.
- **Frequency stream**: same trunk on log-magnitude rFFT (B, 2, 129) →
  256-d.
- **History stream**: MLP on (5 RR intervals, 5×3 one-hot of
  preceding-beat labels) → 64-d.

Trained on **LTAF + CPSC2021 + AFDB** beats (701 561 N + 184 130 A +
169 946 V), AAMI EC57 mapping. Test F1 on LTAF: 0.939.
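
The stream shapes above follow from simple arithmetic, which the sketch below makes explicit. It is an illustration only: `log_magnitude_rfft` is a hypothetical helper, and the exact log transform (here `log(|X| + eps)`) is an assumption, since the card only says "log-magnitude rFFT".

```python
import torch


def log_magnitude_rfft(beats, eps=1e-6):
    """Map (B, 2, 256) beat windows to (B, 2, 129) log-magnitude spectra."""
    spec = torch.fft.rfft(beats, dim=-1)  # complex tensor, 256 // 2 + 1 = 129 bins
    return torch.log(spec.abs() + eps)    # eps avoids log(0); its value is an assumption


x = torch.randn(4, 2, 256)
freq_in = log_magnitude_rfft(x)
print(freq_in.shape)  # torch.Size([4, 2, 129])

# History-stream input width: 5 RR intervals + 5 one-hot labels over {N, A, V}.
hist_in_dim = 5 + 5 * 3  # 20

# Fused feature: time (256) + frequency (256) + history (64) = 576 dims.
time_feat, freq_feat, hist_feat = torch.zeros(4, 256), torch.zeros(4, 256), torch.zeros(4, 64)
fused = torch.cat([time_feat, freq_feat, hist_feat], dim=-1)
print(fused.shape)  # torch.Size([4, 576])
```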

### Rhythm head (`rhythm_classifier.pt`, 892 806 params)

4-layer Transformer encoder over the per-beat (576+1 = 577)-d
sequence:

- **Input projection** Linear(577, 128).
- **Positional embedding** learned, max 64 beats.
- **4 × _TransformerBlock**: 4-head MHA + 4×128-d MLP with GELU.
- **Masked-mean pool** over valid beat positions.
- **MLP head** Linear(128, 128) → GELU → Dropout(0.1) → Linear(128, 6).

Pretrained binary {NSR, AFIB} on 4 corpora (LTAF train + CPSC2021 +
AFDB + Icentia 200, 8 epochs), then fine-tuned 30 epochs on LTAF
6-class (train+val merged, val-stratified checkpointing on 8 held-back
records).
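
The components listed above can be sketched as follows. This is not the released code: it substitutes PyTorch's built-in `nn.TransformerEncoderLayer` for the repository's `_TransformerBlock`, and it assumes pre-norm blocks, a feed-forward width of 4 × 128 = 512, and a final LayerNorm before pooling. Under those assumptions the parameter count happens to equal the 892 806 stated above, which suggests (but does not prove) a similar layout.

```python
import torch
from torch import nn


class RhythmHeadSketch(nn.Module):
    """Illustrative re-creation of the rhythm head; names are hypothetical."""

    def __init__(self, in_dim=577, d_model=128, n_layers=4, n_heads=4,
                 max_beats=64, n_classes=6):
        super().__init__()
        self.proj = nn.Linear(in_dim, d_model)
        self.pos = nn.Parameter(torch.zeros(max_beats, d_model))  # learned positions
        self.blocks = nn.ModuleList([
            nn.TransformerEncoderLayer(d_model, n_heads, dim_feedforward=4 * d_model,
                                       dropout=0.1, activation="gelu",
                                       batch_first=True, norm_first=True)
            for _ in range(n_layers)
        ])
        self.norm = nn.LayerNorm(d_model)  # assumption: final norm before pooling
        self.head = nn.Sequential(nn.Linear(d_model, d_model), nn.GELU(),
                                  nn.Dropout(0.1), nn.Linear(d_model, n_classes))

    def forward(self, feats, rr_extra, valid_mask):
        x = torch.cat([feats, rr_extra], dim=-1)  # (B, T, 576 + 1)
        x = self.proj(x) + self.pos[: x.size(1)]
        for block in self.blocks:
            # key_padding_mask convention: True marks positions to ignore.
            x = block(x, src_key_padding_mask=~valid_mask)
        x = self.norm(x)
        m = valid_mask.unsqueeze(-1).to(x.dtype)  # (B, T, 1)
        pooled = (x * m).sum(dim=1) / m.sum(dim=1).clamp(min=1.0)  # masked mean
        return self.head(pooled)                  # (B, n_classes)


model = RhythmHeadSketch().eval()
n_params = sum(p.numel() for p in model.parameters())

feats, rr_extra = torch.randn(2, 32, 576), torch.rand(2, 32, 1)
# First bout has 32 valid beats, second has 20 (the rest is padding).
valid_mask = torch.arange(32).unsqueeze(0) < torch.tensor([[32], [20]])
with torch.no_grad():
    logits = model(feats, rr_extra, valid_mask)
print(n_params, logits.shape)  # 892806 torch.Size([2, 6])
```

The masked mean divides by the number of valid beats rather than by `T`, so padded positions do not dilute the pooled representation of short bouts.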

## Training

Reproducer in [rmxjck/TSLM-Arena](https://github.com/) (commit
forthcoming):

```bash
# Stage 0: train multi-corpus HTF beats classifier
.venv/bin/python scripts/train_ecg_beat_htf_multicorpus.py \
    --corpora ltaf cpsc2021 afdb \
    --epochs 12 --batch-size 256 --lr 1e-3 \
    --output-dir results/ecg_classifier/beats_htf_multicorpus

# Stage 1: binary {NSR, AFIB} pretrain on 4 corpora
.venv/bin/python scripts/train_rhythm_from_beats_multicorpus.py \
    --htf-checkpoint results/ecg_classifier/beats_htf_multicorpus/best_classifier.pt \
    --corpora ltaf cpsc2021 afdb icentia \
    --icentia-records 200 --max-bouts-per-record 100 \
    --epochs 8 --batch-size 32 --lr 5e-4 \
    --classes NSR AFIB \
    --output-dir results/ecg_classifier/sweep/c2_rhythm_from_beats_pretrain

# Stage 2: LTAF 6-class fine-tune
.venv/bin/python scripts/train_rhythm_from_beats_finetune.py \
    --htf-checkpoint results/ecg_classifier/beats_htf_multicorpus/best_classifier.pt \
    --pretrained results/ecg_classifier/sweep/c2_rhythm_from_beats_pretrain/best_classifier.pt \
    --window-seconds 10 --max-beats 32 \
    --epochs 30 --batch-size 32 --lr 5e-4 \
    --use-val-as-train \
    --classes NSR AFIB SBR AB SVTA B \
    --output-dir results/ecg_classifier/sweep/c6_rhythm_from_beats_v2_ft

# Eval with TTA-7
.venv/bin/python scripts/eval_with_tta_rfb.py \
    --htf-checkpoint results/ecg_classifier/beats_htf_multicorpus/best_classifier.pt \
    --checkpoint results/ecg_classifier/sweep/c6_rhythm_from_beats_v2_ft/best_classifier.pt \
    --n-views 7 \
    --output results/ecg_classifier/sweep/c6_rhythm_from_beats_v2_ft_tta.json
```

Best val macro F1 0.741 at epoch 23. Test single-view 0.731, TTA-7
0.767. Total training time on a single H100: ~25 min for HTF + ~1 min
for rhythm pretrain + ~20 min for rhythm fine-tune + ~30 s for TTA eval.

## Ablations

| Variant | Test macro F1 |
|---|---:|
| **v2 + TTA-7** | **0.767** |
| v2 single-window | 0.731 |
| v2 + v4 ensemble + TTA-7 | 0.756 |
| v1 (RhythmResNet1D + TTA-16, no beat embedding) | 0.658 |
| v4 (fp32 cache, 6-corpus pretrain w/ mitdb+nsrdb+Icentia 2k) + TTA-7 | 0.711 |
| v3 (fp16 cache regression) + ft | 0.684 |
| v1 + v2 ensemble + TTA-7 | 0.710 |

Notable negative results (didn't help):

- **fp16 feature cache**: 5 pp regression vs live HTF (precision loss
  in the 576-d embeddings hurt the rhythm head's discriminative
  power). v2 uses live (uncached) HTF features. The fp32 cache in v4
  speeds up training, but v4's larger pretraining corpus is skewed
  toward NSR.
- **Multi-model ensembles** (v1+v2, v2+v4): mild regressions relative
  to v2 alone, because v1 and v4 are weaker models.
- **Multi-corpus MAE pretrain on raw signal** (v3 transformer): F1
  0.34. Beat-level supervision matters more than raw-signal
  reconstruction at this corpus size.

## Not for clinical use

Research artifact only. Not FDA-cleared. Not suitable for triage,
diagnosis, or any patient-facing application.

## Citations

```bibtex
@misc{petrutiu2008ltafdb,
  title         = {Abrupt Changes in Fibrillatory Wave Characteristics at the Termination of Paroxysmal Atrial Fibrillation in Humans},
  author        = {Petrutiu, Simona and Sahakian, Alan V. and Swiryn, Steven},
  year          = {2008},
  howpublished  = {PhysioNet},
  url           = {https://physionet.org/content/ltafdb/}
}
```

CPSC 2021, MIT-BIH AFDB and Icentia11k were also used for the
multi-corpus stages.