---
library_name: pytorch
tags:
- ecg
- arrhythmia
- rhythm-classification
- ltaf
- physionet
- 1d-resnet
license: mit
datasets:
- physionet/ltafdb
---

# LTAF ECG Rhythm Classifier: RhythmResNet1D + TTA

A from-scratch 1D-ResNet trained on PhysioNet's
[Long-Term Atrial Fibrillation (LTAF)](https://physionet.org/content/ltafdb/)
database for **6-class rhythm classification** on two-lead 128 Hz ECG.

| Metric | Single-window | **+ 7-view TTA (recommended)** |
|---|---:|---:|
| Test accuracy | 0.636 | **0.684** |
| Test balanced accuracy | 0.740 | **0.778** |
| **Test macro F1** | **0.614** | **0.656** |

For comparison, a frozen Chronos-2 + MLP baseline on the same 6-class subset
reaches test macro F1 = 0.299; the TTA model is **+36 pp, or 2.2× the macro F1**.

Per-class F1 (TTA-7): NSR 0.76, AFIB 0.62, SBR 0.82, AB 0.77, SVTA 0.15, B 0.82.

## Classes

| Code | Expansion |
|------|-----------|
| NSR  | Normal sinus rhythm |
| AFIB | Atrial fibrillation |
| SBR  | Sinus bradycardia (<60 bpm, sinus origin) |
| AB   | Atrial bigeminy (every other beat is an APC) |
| SVTA | Supraventricular tachyarrhythmia (≥3 consecutive SV ectopic beats at >100 bpm) |
| B    | Ventricular bigeminy (every other beat is a PVC) |

`VT` (ventricular tachycardia), `T` (ventricular trigeminy), and `IVR` (idioventricular rhythm) are excluded: their LTAF test supports (31, 26, and 1) are too small for stable F1 estimation.

## Quickstart

```bash
pip install torch huggingface_hub numpy
```

```python
import os
import sys

import torch
from huggingface_hub import hf_hub_download

REPO_ID = "rmxjck/ltaf-ecg-rhythm-classifier"
device = "cuda" if torch.cuda.is_available() else "cpu"

# Download model code (model.py) + checkpoint from HF; put model.py's folder on sys.path
model_py = hf_hub_download(REPO_ID, "model.py")
sys.path.insert(0, os.path.dirname(model_py))
from model import RhythmResNet1D, RHYTHM_CLASS_NAMES  # noqa: E402

ckpt = hf_hub_download(REPO_ID, "best_classifier.pt")
model = RhythmResNet1D.load(ckpt, device=device)
model.eval()

# Input: (B, 2, 1280), i.e. 10 s @ 128 Hz, 2 leads, per-channel z-scored.
x = torch.randn(1, 2, 1280, device=device)  # replace with real ECG
with torch.no_grad():
    logits = model(x)
    pred_idx = logits.argmax(-1).item()
print(model.class_names[pred_idx])
```

For best results, use the **7-view TTA** wrapper in `inference.py`, which
averages the softmax over 7 random window-start offsets. It adds ~4 pp macro F1
at the cost of 7× inference compute.

```bash
python inference.py
```

## Architecture

`RhythmResNet1D(num_classes=6, n_channels=2, base_channels=64,
blocks_per_stage=2)`:

- **Stem:** Conv1d(2, 64, k=15, stride=2) → BN → ReLU → MaxPool(2).
- **4 ResNet stages × 2 basic blocks** (Conv1d k=7, BN, ReLU, Dropout, +skip).
  Channels: 64 → 128 → 256 → 512. Each stage after the first halves the
  time axis at its first block.
- **Head:** AdaptiveAvgPool1d → Linear(512 → 128) → ReLU → Dropout(0.2)
  → Linear(128 → 6).
- **Total parameters:** 8,794,246.
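The architecture bullets above can be sketched in PyTorch as follows. This is a hypothetical reconstruction, not the `RhythmResNet1D` class from `model.py`: the class names, bias-free convolutions, "same" padding, the dropout position and rate (0.1) inside each block, and a 1×1 conv + BN projection on the skip path whenever the shape changes are all assumptions. Under those choices the sketch totals 8,794,246 parameters, the same count listed above, which suggests (but does not prove) that the original makes similar choices.

```python
import torch
import torch.nn as nn


class BasicBlock1D(nn.Module):
    """Two k=7 convs with BN/ReLU/Dropout and a residual skip (hypothetical sketch)."""

    def __init__(self, in_ch, out_ch, stride=1, dropout=0.1):
        super().__init__()
        self.conv1 = nn.Conv1d(in_ch, out_ch, 7, stride=stride, padding=3, bias=False)
        self.bn1 = nn.BatchNorm1d(out_ch)
        self.conv2 = nn.Conv1d(out_ch, out_ch, 7, padding=3, bias=False)
        self.bn2 = nn.BatchNorm1d(out_ch)
        self.relu = nn.ReLU()
        self.drop = nn.Dropout(dropout)  # assumed position and rate
        # Assumed: 1x1 conv + BN projection when channels or length change
        self.skip = nn.Identity()
        if stride != 1 or in_ch != out_ch:
            self.skip = nn.Sequential(
                nn.Conv1d(in_ch, out_ch, 1, stride=stride, bias=False),
                nn.BatchNorm1d(out_ch),
            )

    def forward(self, x):
        out = self.drop(self.relu(self.bn1(self.conv1(x))))
        out = self.bn2(self.conv2(out))
        return self.relu(out + self.skip(x))


class RhythmResNet1DSketch(nn.Module):
    """Stem -> 4 stages x blocks_per_stage basic blocks -> pooled MLP head."""

    def __init__(self, num_classes=6, n_channels=2, base_channels=64, blocks_per_stage=2):
        super().__init__()
        c = base_channels
        self.stem = nn.Sequential(
            nn.Conv1d(n_channels, c, 15, stride=2, padding=7, bias=False),
            nn.BatchNorm1d(c),
            nn.ReLU(),
            nn.MaxPool1d(2),
        )
        blocks, in_ch = [], c
        for stage, out_ch in enumerate([c, 2 * c, 4 * c, 8 * c]):
            for b in range(blocks_per_stage):
                # Halve the time axis at the first block of every stage after the first
                stride = 2 if (stage > 0 and b == 0) else 1
                blocks.append(BasicBlock1D(in_ch, out_ch, stride))
                in_ch = out_ch
        self.stages = nn.Sequential(*blocks)
        self.head = nn.Sequential(
            nn.AdaptiveAvgPool1d(1),
            nn.Flatten(),
            nn.Linear(8 * c, 128),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(128, num_classes),
        )

    def forward(self, x):  # x: (B, 2, 1280)
        return self.head(self.stages(self.stem(x)))


model = RhythmResNet1DSketch()
n_params = sum(p.numel() for p in model.parameters())
print(n_params)  # 8794246
logits = model(torch.randn(4, 2, 1280))
print(tuple(logits.shape))  # (4, 6)
```

With a 1280-sample input, the time axis shrinks 1280 → 640 (stem conv) → 320 (max-pool) → 160 → 80 → 40 before global average pooling, so the head always sees a 512-dim vector regardless of the exact window length.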

## Input format

- `(B, 2, 1280)` float32
- 2-lead ECG at **128 Hz** (LTAF leads `ECG1`, `ECG2`)
- 10 s window
- Per-channel z-scored along time: `(x - x.mean(axis=-1, keepdims=True)) / x.std(axis=-1, keepdims=True)`
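For a long two-lead recording, one way to produce inputs in this format is to tile it into 10 s windows and z-score each window per channel. The sketch below is illustrative only: `make_windows`, non-overlapping tiling, dropping the trailing partial window, and the small `eps` guard against flat channels are assumptions, not the card's preprocessing pipeline.

```python
import numpy as np

FS = 128          # LTAF sampling rate (Hz)
WINDOW = 10 * FS  # 1280 samples = one 10 s model input


def make_windows(record, eps=1e-8):
    """Tile a (channels, L) recording into z-scored (N, channels, 1280) windows.

    Hypothetical helper: windows are non-overlapping and any tail shorter than
    10 s is dropped. Each window is z-scored per channel along the time axis;
    `eps` (an assumption) avoids division by zero on a flat channel.
    """
    record = np.asarray(record, dtype=np.float32)
    n_ch, length = record.shape
    n = length // WINDOW
    # (n_ch, n*WINDOW) -> (n_ch, n, WINDOW) -> (n, n_ch, WINDOW)
    windows = record[:, : n * WINDOW].reshape(n_ch, n, WINDOW).transpose(1, 0, 2)
    mean = windows.mean(axis=-1, keepdims=True)
    std = windows.std(axis=-1, keepdims=True)
    return (windows - mean) / (std + eps)


rng = np.random.default_rng(0)
recording = rng.normal(size=(2, 65 * FS))  # stand-in for a 65 s two-lead excerpt
x = make_windows(recording)
print(x.shape)  # (6, 2, 1280)
```

`keepdims=True` keeps the per-channel statistics shaped `(N, 2, 1)` so they broadcast against the `(N, 2, 1280)` windows; a 65 s recording yields six windows and the last 5 s are discarded.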

## Test-time augmentation (TTA)

Pass a longer signal slice (≥1280 samples) to `predict_tta()`: it
samples 7 random 10 s windows, averages their softmax outputs, and takes
the argmax. Why it helps: training samples window starts at random
within each rhythm bout, so the model learns to be invariant to that
shift. At eval time, averaging predictions over several shifts cancels
position-specific noise. **+4.2 pp test macro F1, no retraining.**

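```python
from inference import predict_tta  # the TTA wrapper in inference.py (import path assumed)

# long_signal: (2, 30*128) two-lead ECG slice, 30 s @ 128 Hz
cls, prob, full_probs = predict_tta(model, long_signal, n_views=7, device="cuda")
```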
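The averaging described above can be sketched as follows. `predict_tta_sketch` is a hypothetical stand-in, not the `predict_tta()` shipped in `inference.py`: uniform random window starts, re-z-scoring each crop per channel, and a `(class index, top probability, full probability vector)` return value are assumptions chosen to mirror the call shown above.

```python
import numpy as np
import torch
import torch.nn.functional as F


def predict_tta_sketch(model, long_signal, n_views=7, window=1280, device="cpu", seed=None):
    """Average softmax over `n_views` random 10 s crops of a (2, L) signal (hypothetical)."""
    signal = np.asarray(long_signal, dtype=np.float32)
    length = signal.shape[-1]
    if length < window:
        raise ValueError(f"need at least {window} samples, got {length}")
    rng = np.random.default_rng(seed)
    # Random window starts, mirroring the random-offset sampling used in training
    starts = rng.integers(0, length - window + 1, size=n_views)
    views = np.stack([signal[:, s:s + window] for s in starts])  # (n_views, 2, window)
    # Per-crop, per-channel z-score (assumed to match the training preprocessing)
    views = (views - views.mean(axis=-1, keepdims=True)) / (
        views.std(axis=-1, keepdims=True) + 1e-8
    )
    x = torch.from_numpy(views).to(device)
    model.eval()
    with torch.no_grad():
        probs = F.softmax(model(x), dim=-1).mean(dim=0)  # average over views
    cls = int(probs.argmax())
    return cls, float(probs[cls]), probs.cpu().numpy()


# Usage with a dummy linear classifier in place of RhythmResNet1D
dummy = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(2 * 1280, 6))
long_signal = np.random.default_rng(1).normal(size=(2, 30 * 128))
cls, prob, full_probs = predict_tta_sketch(dummy, long_signal, n_views=7, seed=0)
print(cls, round(prob, 3))
```

Because all 7 crops go through the model as one batch, the 7× compute cost shows up as a larger batch rather than 7 sequential forward calls.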

## Training recipe

```bash
.venv/bin/python scripts/train_ecg_rhythm_scratch.py \
    --arch resnet1d --window-sizes 10 \
    --epochs 30 --batch-size 64 --lr 5e-4 \
    --base-channels 64 \
    --use-val-as-train \
    --classes NSR AFIB SBR AB SVTA B \
    --output-dir results/ecg_classifier/sweep/c6_resnet1d_w10_e30_wide
```

- Dataset: LTAF train+val combined (75 records), of which 8 records are held
  out for early stopping. The test split (9 records, 3,716 windows) is never
  used for training or model selection.
- Loss: weighted cross-entropy with sqrt-dampened inverse-frequency
  class weights (cap 10), label smoothing 0.1.
- Cosine LR schedule from 5e-4 to 0 over 30 epochs; AdamW (weight decay 1e-4).
- Best checkpoint by held-out macro F1.
- Training time on a single H100 80GB: **~6 minutes**.
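The loss and optimizer bullets above can be sketched in PyTorch as follows. The card does not give the exact weighting formula, so `sqrt_inverse_frequency_weights` assumes w_c = sqrt(N / (K * n_c)) clipped at 10; the class counts, the linear stand-in model, the random placeholder batches, and stepping the scheduler once per epoch are also assumptions.

```python
import torch
import torch.nn as nn


def sqrt_inverse_frequency_weights(counts, cap=10.0):
    """Sqrt-dampened inverse-frequency class weights, clipped at `cap`.

    Assumed form: w_c = sqrt(N / (K * n_c)), so perfectly balanced classes get 1.0.
    """
    counts = torch.as_tensor(counts, dtype=torch.float32).clamp(min=1.0)
    weights = torch.sqrt(counts.sum() / (counts.numel() * counts))
    return weights.clamp(max=cap)


# Hypothetical per-class window counts (NSR, AFIB, SBR, AB, SVTA, B); not real LTAF counts
train_counts = [20000, 15000, 3000, 2500, 400, 900]
class_weights = sqrt_inverse_frequency_weights(train_counts)
criterion = nn.CrossEntropyLoss(weight=class_weights, label_smoothing=0.1)

model = nn.Linear(2 * 1280, 6)  # stand-in for RhythmResNet1D
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-4, weight_decay=1e-4)
# Cosine decay from 5e-4 to 0 over 30 epochs, stepped once per epoch (assumption)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=30, eta_min=0.0)

lrs = []
for epoch in range(30):
    x = torch.randn(64, 2 * 1280)  # placeholder batch (batch size 64)
    y = torch.randint(0, 6, (64,))
    optimizer.zero_grad()
    loss = criterion(model(x), y)
    loss.backward()
    optimizer.step()
    lrs.append(optimizer.param_groups[0]["lr"])
    scheduler.step()
```

The square root softens plain inverse-frequency weighting so a rare class such as SVTA is up-weighted without dominating the loss, and the cap bounds the weight of any extremely rare class.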

Source repo: `scripts/train_ecg_rhythm_scratch.py` and
`src/models/ts_llm/ecg_rhythm_scratch.py` in
[rmxjck/TSLM-Arena](https://github.com/rmxjck/TSLM-Arena).

## Test set details

LTAF held-out split (deterministic seed 42, record-level): 9 records
(`100, 104, 105, 11, 200, 32, 48, 49, 68`), 3,716 windows.

Confusion matrix (rows = true, cols = pred), single-window (no TTA):

|       | NSR | AFIB | SBR | AB  | SVTA | B   |
|-------|----:|-----:|----:|----:|-----:|----:|
| NSR   | 1109 | 286 | 95  | 114 | 185  | 35  |
| AFIB  | 189  | 628 | 25  | 29  | 294  | 14  |
| SBR   | 26   | 0   | 279 | 0   | 0    | 0   |
| AB    | 9    | 13  | 0   | 225 | 3    | 0   |
| SVTA  | 9    | 14  | 0   | 3   | 34   | 0   |
| B     | 4    | 0   | 0   | 1   | 3    | 90  |

Per-class supports: NSR 1824, AFIB 1179, SBR 305, AB 250, SVTA 60, B 98.
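Recomputing the headline metrics from the confusion matrix is a quick consistency check. The sketch below uses NumPy and the counts from the table above; per-class F1 is computed as 2·TP / (true count + predicted count), which equals the harmonic mean of precision and recall.

```python
import numpy as np

CLASSES = ["NSR", "AFIB", "SBR", "AB", "SVTA", "B"]
# Rows = true class, columns = predicted class (counts from the table above)
cm = np.array([
    [1109, 286,  95, 114, 185, 35],
    [ 189, 628,  25,  29, 294, 14],
    [  26,   0, 279,   0,   0,  0],
    [   9,  13,   0, 225,   3,  0],
    [   9,  14,   0,   3,  34,  0],
    [   4,   0,   0,   1,   3, 90],
])

tp = np.diag(cm).astype(float)
support = cm.sum(axis=1)    # true windows per class
predicted = cm.sum(axis=0)  # predicted windows per class
recall = tp / support
f1 = 2 * tp / (support + predicted)  # same as 2PR / (P + R)

accuracy = tp.sum() / cm.sum()
balanced_accuracy = recall.mean()
macro_f1 = f1.mean()

for name, score in zip(CLASSES, f1):
    print(f"{name:5s} F1 = {score:.2f}")
print(f"accuracy={accuracy:.3f} balanced={balanced_accuracy:.3f} macro_f1={macro_f1:.3f}")
# accuracy=0.636 balanced=0.740 macro_f1=0.614
```

With these counts the script reports accuracy 0.636, balanced accuracy 0.740, and macro F1 0.614, which match the single-window column of the summary table.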

## What was tried and didn't help

This model was the best of 30+ experiments. What did *not* improve over
this baseline:

- HRV side-channel input (8-dim RR-derived features fused with CNN trunk):
  hurts F1 by 3-8 pp because the CNN already extracts equivalent
  information from raw QRS timing.
- Cross-corpus augmentation (MIT-BIH AFDB added to training): hurts
  AFIB F1 by 14 pp because AFDB's clean AFIB blocks bias the model
  toward over-calling AFIB on LTAF's paroxysmal transitions.
- Wider models (96-channel, 12 M params): overfit.
- Longer training (50 epochs): overfits.
- Multi-model soft-voting ensembles: members make correlated errors.
- Focal loss: matches CE within noise.
- Multi-scale training (5 / 10 / 30 s windows): underperforms 10 s alone.
- Bigger external models (torchecg ResNet-50 with 51.9 M params, Stanford
  with 27 M): underperform a 2.2 M-param home-rolled ResNet1D at 12 epochs.

## Not for clinical use

Research artifact only. **Not FDA-cleared.** Not suitable for triage,
diagnosis, or any patient-facing application. Uses the LTAF benchmark,
which has known label noise from its original PhysioNet curation.

## Citation

```bibtex
@misc{petrutiu2008ltafdb,
  title         = {Abrupt Changes in Fibrillatory Wave Characteristics at the Termination of Paroxysmal Atrial Fibrillation in Humans},
  author        = {Petrutiu, Simona and Sahakian, Alan V. and Swiryn, Steven},
  year          = {2008},
  howpublished  = {PhysioNet},
  url           = {https://physionet.org/content/ltafdb/}
}
```