---
library_name: pytorch
tags:
- ecg
- classification
- chronos-2
- ltaf
- arrhythmia
license: mit
---

# LTAF ECG Beat Classifier (N / A / V)

Frozen **Chronos-2** (`amazon/chronos-2`) multivariate encoder + MLP head,
trained on the PhysioNet Long-Term Atrial Fibrillation (LTAF) database
for per-beat classification.

## Classes

| Code | Expansion |
|------|-----------|
| N | Normal sinus-origin beat |
| A | Atrial premature contraction (APC / PAC / SVE) |
| V | Ventricular premature contraction (PVC / VE) |

`Q` (unclassifiable / paced; ~89 of ~9 M beats in the LTAF subset) is dropped.
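The class mapping above amounts to a lookup that discards `Q` beats. Below is a minimal sketch. It assumes beat annotations are available as parallel lists of R-peak sample indices and single-character symbols (as in WFDB annotation files). `BEAT_CLASS_INDEX` and `encode_beats` are hypothetical names, and the indices follow the `class_names` order `["N", "A", "V"]`.

```python
from typing import Iterable

# Hypothetical mapping; index order follows class_names = ["N", "A", "V"].
BEAT_CLASS_INDEX = {"N": 0, "A": 1, "V": 2}


def encode_beats(samples: Iterable[int], symbols: Iterable[str]) -> list[tuple[int, int]]:
    """Pair each R-peak sample index with a class index, dropping Q and any other symbol."""
    encoded = []
    for sample, symbol in zip(samples, symbols):
        label = BEAT_CLASS_INDEX.get(symbol)
        if label is not None:  # "Q" (unclassifiable / paced) is skipped here
            encoded.append((sample, label))
    return encoded


print(encode_beats([100, 230, 355, 480], ["N", "Q", "V", "A"]))
# [(100, 0), (355, 2), (480, 1)]
```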

## Input

- `(B, 2, 256)` — 2-lead ECG at **128 Hz**, 2-second window **centered on the R-peak sample**
- Per-channel z-scored
- LTAF leads: `ECG1`, `ECG2`
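One way to build this input from a raw record is sketched below with NumPy. The sketch rests on several assumptions:

- The record is a `(2, T)` array (`ECG1`, `ECG2`) sampled at 128 Hz, with known R-peak sample indices.
- "Centered" means the R-peak lands at index 128 of the 256-sample window.
- Z-scoring uses each window's own per-channel mean and standard deviation.
- Beats too close to the record edges are skipped.

`beat_windows` is a hypothetical helper, not part of the released code.

```python
import numpy as np

FS = 128            # sampling rate (Hz), as stated on the card
WINDOW = 256        # 2 s @ 128 Hz
HALF = WINDOW // 2  # assumption: R-peak sits at index 128 of each window


def beat_windows(signal: np.ndarray, r_peaks: np.ndarray, eps: float = 1e-8) -> np.ndarray:
    """Cut (2, 256) windows centered on each R-peak and z-score each channel per window.

    signal: (2, T) array holding ECG1 and ECG2.
    r_peaks: 1-D array of R-peak sample indices.
    Beats whose window would run past either end of the record are skipped.
    """
    n_samples = signal.shape[1]
    windows = []
    for r in r_peaks:
        start, stop = int(r) - HALF, int(r) + HALF
        if start < 0 or stop > n_samples:
            continue
        w = signal[:, start:stop].astype(np.float32)
        mean = w.mean(axis=1, keepdims=True)  # per-channel statistics of this window
        std = w.std(axis=1, keepdims=True)
        windows.append((w - mean) / (std + eps))
    if not windows:
        return np.empty((0, 2, WINDOW), dtype=np.float32)
    return np.stack(windows)  # (B, 2, 256)


rng = np.random.default_rng(0)
sig = rng.standard_normal((2, 10 * FS))             # 10 s of synthetic 2-lead signal
x = beat_windows(sig, np.array([50, 300, 640, 1270]))
print(x.shape)  # (2, 2, 256): the beats at 50 and 1270 are too close to the edges
```

Whether the released pipeline normalizes per window or with per-record statistics is not stated here. Swapping the `mean`/`std` lines is the only change needed to try the other convention.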

## Checkpoint details

| Field | Value |
|---|---|
| `num_classes` | 3 |
| `class_names` | `["N", "A", "V"]` |
| `window_samples` | 256 (2 s @ 128 Hz) |
| `n_channels` | 2 |
| `chronos_model_id` | `amazon/chronos-2` |
| `freeze_encoder` | `true` (only the head's 395,267 params were trained) |
| Head | 2-layer MLP: `Linear(768, 512) → ReLU → Dropout(0.3) → Linear(512, 3)` |
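A quick arithmetic check of the head size is below, written as plain Python so it runs without PyTorch. It counts the weights plus biases of each `Linear` layer; `ReLU` and `Dropout` have no parameters. `linear_params` and `head_params` are illustrative helpers. `in_dim` (the width of the Chronos-2 embedding fed to the head) is left as a parameter because it is an assumption here. 768 is the value that reproduces the 395,267 trainable parameters stated in the table.

```python
def linear_params(in_features: int, out_features: int) -> int:
    """Weight matrix plus bias vector of a torch.nn.Linear layer."""
    return in_features * out_features + out_features


def head_params(in_dim: int, hidden: int = 512, n_classes: int = 3) -> int:
    """Trainable parameters of Linear(in_dim, hidden) -> ReLU -> Dropout -> Linear(hidden, n_classes)."""
    return linear_params(in_dim, hidden) + linear_params(hidden, n_classes)


print(head_params(768))  # 395267 = (768*512 + 512) + (512*3 + 3)
```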

## Usage

```python
import torch
from huggingface_hub import hf_hub_download
from src.models.ts_llm.ecg_classifier import EcgRhythmClassifier

path = hf_hub_download("rmxjck/ltaf-ecg-beats-classifier", "best_classifier.pt")
model = EcgRhythmClassifier.load(path, device="cuda")
model.eval()  # assumes an nn.Module; disables the head's Dropout(0.3) at inference

# x: (B, 2, 256) float32 tensor at 128 Hz, per-channel z-scored, centered on the R-peak
x = x.to("cuda")
with torch.no_grad():
    logits = model(x)  # (B, 3)
pred = logits.argmax(-1)  # 0=N, 1=A, 2=V
```
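To report a label and a score per beat, the logits can be moved to NumPy (for example with `logits.detach().cpu().numpy()`) and passed through a softmax. This is a minimal sketch: `decode` is a hypothetical helper, and the softmax scores are not necessarily calibrated probabilities.

```python
import numpy as np

CLASS_NAMES = ["N", "A", "V"]  # order from the checkpoint's class_names


def decode(logits: np.ndarray) -> list[tuple[str, float]]:
    """Turn (B, 3) logits into (label, softmax score) pairs."""
    shifted = logits - logits.max(axis=-1, keepdims=True)  # subtract row max for numerical stability
    probs = np.exp(shifted)
    probs /= probs.sum(axis=-1, keepdims=True)
    idx = probs.argmax(axis=-1)
    return [(CLASS_NAMES[i], float(probs[row, i])) for row, i in enumerate(idx)]


print(decode(np.array([[4.0, 0.5, 0.1], [0.2, 0.1, 3.0]])))
```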

## Training

Produced by `scripts/train_ecg_classifier.py` in
[rmxjck/TSLM-Arena](https://github.com/) on the LTAF-Haystack split
(67 train / 8 val / 9 test records; deterministic, seed 42). Each epoch,
N beats are subsampled to `negative_k × n_nonN` (`negative_k` defaults to
2.0), where `n_nonN` is the number of non-N (A and V) beats, to rebalance
the roughly 97 % N / 1.7 % A / 1.5 % V class distribution.
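The per-epoch subsampling can be sketched as below. This is a minimal illustration, not the code in `scripts/train_ecg_classifier.py`. `epoch_indices` is a hypothetical helper, and labels are assumed to be integer-encoded as 0=N, 1=A, 2=V. Drawing a fresh N subset each epoch by changing `seed` (for example, base seed plus epoch number) is also an assumption.

```python
import random


def epoch_indices(labels: list[int], negative_k: float = 2.0, seed: int = 0) -> list[int]:
    """Keep every A/V beat plus a random subset of N beats for one epoch.

    labels use 0=N, 1=A, 2=V. The number of N beats kept is
    negative_k * n_nonN, capped at the number of N beats available.
    """
    rng = random.Random(seed)  # vary per epoch (hypothetical: base seed + epoch)
    n_idx = [i for i, y in enumerate(labels) if y == 0]
    non_n_idx = [i for i, y in enumerate(labels) if y != 0]
    n_keep = min(len(n_idx), int(negative_k * len(non_n_idx)))
    chosen = rng.sample(n_idx, n_keep) + non_n_idx
    rng.shuffle(chosen)
    return chosen


# Toy label list roughly mirroring the LTAF mix: 9,680 N, 170 A, 150 V.
labels = [0] * 9680 + [1] * 170 + [2] * 150
idx = epoch_indices(labels, negative_k=2.0, seed=42)
print(len(idx))  # 960 = 320 non-N + 640 N
```

With `negative_k = 2.0`, N beats make up two thirds of each epoch instead of ~97 %. Every A and V beat is still seen once per epoch.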

```bash
.venv/bin/python3 scripts/train_ecg_classifier.py \
    --label-class beats --epochs 30 --batch-size 128
```

## Not for clinical use

Research artifact only. Not FDA-cleared. Not suitable for triage,
diagnosis, or any patient-facing application.