Initial release
Browse files- README.md +189 -0
- __pycache__/model.cpython-312.pyc +0 -0
- best_classifier.pt +3 -0
- inference.py +119 -0
- model.py +116 -0
- requirements.txt +3 -0
README.md
ADDED
|
@@ -0,0 +1,189 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
library_name: pytorch
|
| 3 |
+
tags:
|
| 4 |
+
- ecg
|
| 5 |
+
- arrhythmia
|
| 6 |
+
- rhythm-classification
|
| 7 |
+
- ltaf
|
| 8 |
+
- physionet
|
| 9 |
+
- 1d-resnet
|
| 10 |
+
license: mit
|
| 11 |
+
datasets:
|
| 12 |
+
- physionet/ltafdb
|
| 13 |
+
---
|
| 14 |
+
|
| 15 |
+
# LTAF ECG Rhythm Classifier — RhythmResNet1D + TTA
|
| 16 |
+
|
| 17 |
+
A from-scratch 1D-ResNet trained on PhysioNet's
|
| 18 |
+
[Long-Term Atrial Fibrillation (LTAF)](https://physionet.org/content/ltafdb/)
|
| 19 |
+
database for **6-class rhythm classification** on two-lead 128 Hz ECG.
|
| 20 |
+
|
| 21 |
+
| Metric | Single-window | **+ 7-view TTA (recommended)** |
|
| 22 |
+
|---|---:|---:|
|
| 23 |
+
| Test accuracy | 0.636 | **0.684** |
|
| 24 |
+
| Test balanced accuracy | 0.740 | **0.778** |
|
| 25 |
+
| **Test macro F1** | **0.614** | **0.656** |
|
| 26 |
+
|
| 27 |
+
vs. frozen Chronos-2 + MLP baseline on the same 6-class subset:
|
| 28 |
+
test macro F1 = 0.299 — i.e. **+36 pp / 2.2× the F1**.
|
| 29 |
+
|
| 30 |
+
Per-class F1 (TTA-7): NSR 0.76, AFIB 0.62, SBR 0.82, AB 0.77, SVTA 0.15, B 0.82.
|
| 31 |
+
|
| 32 |
+
## Classes
|
| 33 |
+
|
| 34 |
+
| Code | Expansion |
|
| 35 |
+
|------|-----------|
|
| 36 |
+
| NSR | Normal sinus rhythm |
|
| 37 |
+
| AFIB | Atrial fibrillation |
|
| 38 |
+
| SBR | Sinus bradycardia (<60 bpm, sinus origin) |
|
| 39 |
+
| AB | Atrial bigeminy (every other beat is an APC) |
|
| 40 |
+
| SVTA | Supraventricular tachyarrhythmia (≥3 consec SV ectopics @ >100 bpm) |
|
| 41 |
+
| B | Ventricular bigeminy (every other beat is a PVC) |
|
| 42 |
+
|
| 43 |
+
`VT`, `T`, and `IVR` are excluded — their LTAF test supports (31, 26, 1) are too small for stable F1 estimation.
|
| 44 |
+
|
| 45 |
+
## Quickstart
|
| 46 |
+
|
| 47 |
+
```bash
|
| 48 |
+
pip install torch huggingface_hub numpy
|
| 49 |
+
```
|
| 50 |
+
|
| 51 |
+
```python
|
| 52 |
+
import numpy as np
|
| 53 |
+
import torch
|
| 54 |
+
from huggingface_hub import hf_hub_download
|
| 55 |
+
from model import RhythmResNet1D, RHYTHM_CLASS_NAMES
|
| 56 |
+
|
| 57 |
+
# Download the checkpoint from HF (model.py from the same repo must already be importable)
|
| 58 |
+
ckpt = hf_hub_download("rmxjck/ltaf-ecg-rhythm-classifier", "best_classifier.pt")
|
| 59 |
+
model = RhythmResNet1D.load(ckpt, device="cuda")
|
| 60 |
+
model.eval()
|
| 61 |
+
|
| 62 |
+
# Input: (B, 2, 1280) — 10 s @ 128 Hz, 2 leads, per-channel z-scored.
|
| 63 |
+
x = torch.randn(1, 2, 1280).cuda() # replace with real ECG
|
| 64 |
+
with torch.no_grad():
|
| 65 |
+
logits = model(x)
|
| 66 |
+
pred_idx = logits.argmax(-1).item()
|
| 67 |
+
print(model.class_names[pred_idx])
|
| 68 |
+
```
|
| 69 |
+
|
| 70 |
+
For best results, use the **7-view TTA** wrapper in `inference.py`
|
| 71 |
+
(averages softmax across 7 random window-start offsets — adds ~4 pp F1
|
| 72 |
+
at the cost of 7× inference compute).
|
| 73 |
+
|
| 74 |
+
```bash
|
| 75 |
+
python inference.py
|
| 76 |
+
```
|
| 77 |
+
|
| 78 |
+
## Architecture
|
| 79 |
+
|
| 80 |
+
`RhythmResNet1D(num_classes=6, n_channels=2, base_channels=64,
|
| 81 |
+
blocks_per_stage=2)`:
|
| 82 |
+
|
| 83 |
+
- **Stem:** Conv1d(2, 64, k=15, stride=2) → BN → ReLU → MaxPool(2).
|
| 84 |
+
- **4 ResNet stages × 2 basic blocks** (Conv1d k=7, BN, ReLU, Dropout, +skip).
|
| 85 |
+
Channels: 64 → 128 → 256 → 512. Time downsamples 2× at the start of each
|
| 86 |
+
stage past the first.
|
| 87 |
+
- **Head:** AdaptiveAvgPool1d → Linear(512 → 128) → ReLU → Dropout(0.2)
|
| 88 |
+
→ Linear(128 → 6).
|
| 89 |
+
- **Total parameters:** 8,794,246.
|
| 90 |
+
|
| 91 |
+
## Input format
|
| 92 |
+
|
| 93 |
+
- `(B, 2, 1280)` float32
|
| 94 |
+
- 2-lead ECG at **128 Hz** (LTAF leads `ECG1`, `ECG2`)
|
| 95 |
+
- 10 s window
|
| 96 |
+
- Per-channel z-scored: `(x - x.mean(axis=-1, keepdims=True)) / (x.std(axis=-1, keepdims=True) + 1e-6)`
|
| 97 |
+
|
| 98 |
+
## Test-time augmentation (TTA)
|
| 99 |
+
|
| 100 |
+
Pass a longer signal slice (≥1280 samples) to `predict_tta()` and it
|
| 101 |
+
samples 7 random 10 s windows, averages the softmax outputs, then
|
| 102 |
+
argmaxes. Why it helps: training uses random window-start sampling
|
| 103 |
+
within each rhythm bout, so the model learns to be invariant to that
|
| 104 |
+
shift. At eval time, taking multiple shifts and averaging cancels the
|
| 105 |
+
position-specific noise. **+4.2 pp test macro F1, no retraining.**
|
| 106 |
+
|
| 107 |
+
```python
|
| 108 |
+
# (2, 30*128) signal, 30 s long
|
| 109 |
+
cls, prob, full_probs = predict_tta(model, long_signal, n_views=7, device="cuda")
|
| 110 |
+
```
|
| 111 |
+
|
| 112 |
+
## Training recipe
|
| 113 |
+
|
| 114 |
+
```bash
|
| 115 |
+
.venv/bin/python scripts/train_ecg_rhythm_scratch.py \
|
| 116 |
+
--arch resnet1d --window-sizes 10 \
|
| 117 |
+
--epochs 30 --batch-size 64 --lr 5e-4 \
|
| 118 |
+
--base-channels 64 \
|
| 119 |
+
--use-val-as-train \
|
| 120 |
+
--classes NSR AFIB SBR AB SVTA B \
|
| 121 |
+
--output-dir results/ecg_classifier/sweep/c6_resnet1d_w10_e30_wide
|
| 122 |
+
```
|
| 123 |
+
|
| 124 |
+
- Dataset: LTAF train+val combined (75 records). 8 records held out for
|
| 125 |
+
early stopping. Test (9 records, 3,716 windows) untouched.
|
| 126 |
+
- Loss: weighted cross-entropy with sqrt-dampened inverse-frequency
|
| 127 |
+
class weights (cap 10), label smoothing 0.1.
|
| 128 |
+
- Cosine LR schedule from 5e-4 → 0 over 30 epochs. AdamW (wd 1e-4).
|
| 129 |
+
- Best checkpoint by held-out macro F1.
|
| 130 |
+
- Training time on a single H100 80GB: **~6 minutes**.
|
| 131 |
+
|
| 132 |
+
Source repo: `scripts/train_ecg_rhythm_scratch.py` and
|
| 133 |
+
`src/models/ts_llm/ecg_rhythm_scratch.py` in
|
| 134 |
+
[rmxjck/TSLM-Arena](https://github.com/rmxjck/TSLM-Arena).
|
| 135 |
+
|
| 136 |
+
## Test set details
|
| 137 |
+
|
| 138 |
+
LTAF held-out split (deterministic seed 42, record-level): 9 records
|
| 139 |
+
(`100, 104, 105, 11, 200, 32, 48, 49, 68`), 3,716 windows.
|
| 140 |
+
|
| 141 |
+
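Confusion matrix (rows = true, cols = pred). Note: these counts reproduce the single-window metrics above (accuracy 2365/3716 = 0.636, balanced accuracy 0.740), not the TTA ones: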
|
| 142 |
+
|
| 143 |
+
| | NSR | AFIB | SBR | AB | SVTA | B |
|
| 144 |
+
|-------|----:|-----:|----:|----:|-----:|----:|
|
| 145 |
+
| NSR | 1109 | 286 | 95 | 114 | 185 | 35 |
|
| 146 |
+
| AFIB | 189 | 628 | 25 | 29 | 294 | 14 |
|
| 147 |
+
| SBR | 26 | 0 | 279 | 0 | 0 | 0 |
|
| 148 |
+
| AB | 9 | 13 | 0 | 225 | 3 | 0 |
|
| 149 |
+
| SVTA | 9 | 14 | 0 | 3 | 34 | 0 |
|
| 150 |
+
| B | 4 | 0 | 0 | 1 | 3 | 90 |
|
| 151 |
+
|
| 152 |
+
Per-class supports: NSR 1824, AFIB 1179, SBR 305, AB 250, SVTA 60, B 98.
|
| 153 |
+
|
| 154 |
+
## What was tried and didn't help
|
| 155 |
+
|
| 156 |
+
This model was the best of 30+ experiments. What did *not* improve over
|
| 157 |
+
this baseline:
|
| 158 |
+
|
| 159 |
+
- HRV side-channel input (8-dim RR-derived features fused with CNN trunk):
|
| 160 |
+
hurts F1 by 3-8 pp because the CNN already extracts equivalent
|
| 161 |
+
information from raw QRS timing.
|
| 162 |
+
- Cross-corpus augmentation (MIT-BIH AFDB added to training): hurts
|
| 163 |
+
AFIB F1 by 14 pp because AFDB's clean AFIB blocks bias the model
|
| 164 |
+
toward over-calling AFIB on LTAF's paroxysmal transitions.
|
| 165 |
+
- Wider models (96-channel, 12 M params): overfits.
|
| 166 |
+
- Longer training (50 epochs): overfits.
|
| 167 |
+
- Multi-model soft-voting ensembles: members make correlated errors.
|
| 168 |
+
- Focal loss: matches CE within noise.
|
| 169 |
+
- Multi-scale training (5 / 10 / 30 s windows): underperforms 10 s alone.
|
| 170 |
+
- Bigger external models (torchecg ResNet-50 51.9 M, Stanford 27 M):
|
| 171 |
+
underperform a 2.2 M home-rolled ResNet1D at 12 epochs.
|
| 172 |
+
|
| 173 |
+
## Not for clinical use
|
| 174 |
+
|
| 175 |
+
Research artifact only. **Not FDA-cleared.** Not suitable for triage,
|
| 176 |
+
diagnosis, or any patient-facing application. Uses the LTAF benchmark
|
| 177 |
+
which has known label noise from its original PhysioNet curation.
|
| 178 |
+
|
| 179 |
+
## Citation
|
| 180 |
+
|
| 181 |
+
```bibtex
|
| 182 |
+
@misc{petrutiu2008ltafdb,
|
| 183 |
+
title = {Abrupt Changes in Fibrillatory Wave Characteristics at the Termination of Paroxysmal Atrial Fibrillation in Humans},
|
| 184 |
+
author = {Petrutiu, Simona and Sahakian, Alan V. and Swiryn, Steven},
|
| 185 |
+
year = {2008},
|
| 186 |
+
howpublished = {PhysioNet},
|
| 187 |
+
url = {https://physionet.org/content/ltafdb/}
|
| 188 |
+
}
|
| 189 |
+
```
|
__pycache__/model.cpython-312.pyc
ADDED
|
Binary file (6.34 kB). View file
|
|
|
best_classifier.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:e52d7453256a051cbcc5516a1462d9869a5baf8e2e5caeb2d2a0e5b69fa3e961
|
| 3 |
+
size 35256343
|
inference.py
ADDED
|
@@ -0,0 +1,119 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
# SPDX-License-Identifier: MIT
|
| 3 |
+
"""Inference example for the LTAF ECG rhythm classifier.
|
| 4 |
+
|
| 5 |
+
Two modes:
|
| 6 |
+
- Single-window: pass a (B, 2, 1280) z-scored 10 s @ 128 Hz tensor.
|
| 7 |
+
- TTA-7 (recommended, +4 pp F1): pass a longer signal slice and the
|
| 8 |
+
function will pull 7 random 10 s windows from it and soft-vote.
|
| 9 |
+
|
| 10 |
+
Usage:
|
| 11 |
+
.venv/bin/python inference.py
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
from __future__ import annotations
|
| 15 |
+
|
| 16 |
+
from pathlib import Path
|
| 17 |
+
from typing import Tuple
|
| 18 |
+
|
| 19 |
+
import numpy as np
|
| 20 |
+
import torch
|
| 21 |
+
import torch.nn.functional as F
|
| 22 |
+
from huggingface_hub import hf_hub_download
|
| 23 |
+
|
| 24 |
+
from model import RHYTHM_CLASS_NAMES, RhythmResNet1D
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
WINDOW_SECONDS = 10  # duration of one model input window, in seconds
|
| 28 |
+
SOURCE_HZ = 128  # sampling rate of the input ECG, in samples per second
|
| 29 |
+
WINDOW_SAMPLES = WINDOW_SECONDS * SOURCE_HZ # 1280
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def load_model(device: str = "cpu") -> RhythmResNet1D:
    """Fetch ``best_classifier.pt`` from the Hugging Face Hub and build the model.

    The downloaded checkpoint path is handed to ``RhythmResNet1D.load``
    together with ``device``; the resulting model is returned as-is.
    """
    repo_id = "rmxjck/ltaf-ecg-rhythm-classifier"
    filename = "best_classifier.pt"
    checkpoint = hf_hub_download(repo_id, filename)
    return RhythmResNet1D.load(checkpoint, device=device)
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def zscore(window: np.ndarray) -> np.ndarray:
    """Standardise an array along its last axis and return it as float32.

    For a (C, L) input each channel is shifted by its own mean and divided
    by its own standard deviation. A small constant (1e-6) is added to the
    standard deviation so a constant channel maps to zeros instead of
    dividing by zero.
    """
    centered = window - window.mean(axis=-1, keepdims=True)
    scale = window.std(axis=-1, keepdims=True) + 1e-6
    normalised = centered / scale
    return normalised.astype(np.float32, copy=False)
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def predict_single(
    model: RhythmResNet1D,
    window: np.ndarray,
    device: str = "cpu",
) -> Tuple[str, float]:
    """Classify a single (2, 1280) window that has already been z-scored.

    The window is wrapped into a batch of one, moved to ``device`` and run
    through ``model`` with gradients disabled. The softmax of the logits
    gives the class probabilities.

    Returns:
        ``(class_name, prob)`` for the most probable class.

    Raises:
        ValueError: if ``window.shape`` is not ``(2, WINDOW_SAMPLES)``.
    """
    expected_shape = (2, WINDOW_SAMPLES)
    if window.shape != expected_shape:
        raise ValueError(f"Expected (2, {WINDOW_SAMPLES}), got {window.shape}")

    batch = torch.from_numpy(window).float().unsqueeze(0).to(device)
    with torch.no_grad():
        logits = model(batch)
        class_probs = F.softmax(logits, dim=-1)[0]

    best = int(class_probs.argmax().item())
    confidence = float(class_probs[best].item())
    return model.class_names[best], confidence
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def predict_tta(
    model: RhythmResNet1D,
    long_signal: np.ndarray,
    n_views: int = 7,
    device: str = "cpu",
    seed: int = 42,
) -> Tuple[str, float, np.ndarray]:
    """TTA-soft-voting prediction over a longer (2, L) signal.

    Samples ``n_views`` random 10 s windows from ``long_signal`` (L >= 1280),
    z-scores each independently, runs them through the model as one batch,
    and averages the softmax probabilities.

    Args:
        model: classifier exposing ``class_names``; assumed to be in eval
            mode (``RhythmResNet1D.load`` returns it that way).
        long_signal: raw two-channel signal of shape ``(2, L)``.
        n_views: number of random windows to average; must be >= 1.
        device: torch device the batch is moved to before the forward pass.
        seed: seed for the window-start generator, so that repeated calls
            on the same signal pick the same windows.

    Returns:
        ``(class_name, prob, full_probs)`` where ``full_probs`` is a NumPy
        array with one averaged probability per class.

    Raises:
        ValueError: if ``long_signal`` is not 2-D, does not have exactly two
            channels, is shorter than one window, or ``n_views`` < 1.
    """
    if long_signal.ndim != 2:
        raise ValueError(
            f"Expected a 2-D (2, L) signal, got shape {long_signal.shape}"
        )
    n_ch, n_samples = long_signal.shape
    if n_ch != 2:
        raise ValueError(f"Expected 2-channel signal, got {n_ch}")
    if n_samples < WINDOW_SAMPLES:
        raise ValueError(f"Need at least {WINDOW_SAMPLES} samples, got {n_samples}")
    # Without this guard n_views == 0 would average over nothing and return
    # NaN probabilities together with an arbitrary class label.
    if n_views < 1:
        raise ValueError(f"n_views must be >= 1, got {n_views}")

    rng = np.random.default_rng(seed)
    # The upper bound is exclusive, so the last valid start is included.
    starts = rng.integers(0, n_samples - WINDOW_SAMPLES + 1, size=n_views)

    # Stack all views so the model runs once instead of once per view.
    views = np.stack(
        [zscore(long_signal[:, s:s + WINDOW_SAMPLES]) for s in starts]
    )
    x = torch.from_numpy(views).float().to(device)
    with torch.no_grad():
        probs_avg = F.softmax(model(x), dim=-1).mean(dim=0)

    idx = int(probs_avg.argmax().item())
    return model.class_names[idx], float(probs_avg[idx].item()), probs_avg.cpu().numpy()
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
def demo():
    """Smoke-test both prediction paths on random noise.

    Loads the model (on CUDA when available, otherwise CPU), prints a short
    summary, then runs one single-window prediction and one 7-view TTA
    prediction on synthetic input and prints the results.
    """
    print("Loading model from HF...")
    if torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"
    model = load_model(device)
    print(f"Loaded {model.__class__.__name__} on {device}")
    print(f"Classes: {model.class_names}")
    print(f"Params: {sum(p.numel() for p in model.parameters()):,}")

    # Gaussian noise stands in for an ECG here, so the label is meaningless.
    print("\n--- single-window demo (random input) ---")
    noise_window = np.random.randn(2, WINDOW_SAMPLES).astype(np.float32)
    cls, prob = predict_single(model, zscore(noise_window), device=device)
    print(f"prediction: {cls} ({prob:.1%})")

    print("\n--- TTA-7 demo (random 30 s input) ---")
    noise_long = np.random.randn(2, 30 * SOURCE_HZ).astype(np.float32)
    cls, prob, full = predict_tta(model, noise_long, n_views=7, device=device)
    print(f"prediction: {cls} ({prob:.1%})")
    rounded = [round(p, 3) for p in full.tolist()]
    print(f"all class probs: {dict(zip(model.class_names, rounded))}")
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
# Run the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
|
| 119 |
+
demo()
|
model.py
ADDED
|
@@ -0,0 +1,116 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: MIT
|
| 2 |
+
"""Self-contained RhythmResNet1D for LTAF rhythm classification.
|
| 3 |
+
|
| 4 |
+
Vendored from rmxjck/TSLM-Arena (src/models/ts_llm/ecg_rhythm_scratch.py)
|
| 5 |
+
so the model can be loaded with no external project imports.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from __future__ import annotations
|
| 9 |
+
|
| 10 |
+
from pathlib import Path
|
| 11 |
+
from typing import List
|
| 12 |
+
|
| 13 |
+
import torch
|
| 14 |
+
import torch.nn as nn
|
| 15 |
+
import torch.nn.functional as F
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
# Default class labels for RhythmResNet1D; predicted indices are looked up in
# the model's ``class_names`` list, which defaults to this one.
RHYTHM_CLASS_NAMES = ["NSR", "AFIB", "SBR", "AB", "SVTA", "B"]
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class _BasicBlock1D(nn.Module):
    """Residual unit built from two 1-D convolutions.

    Main path: conv -> BN -> ReLU -> dropout -> conv -> BN. The result is
    added to a shortcut of the input and passed through a final ReLU. Only
    the first convolution carries ``stride``. When the stride is not 1 or
    the channel count changes, the shortcut is a strided 1x1 convolution
    followed by BN; otherwise it is the identity.
    """

    def __init__(self, in_c: int, out_c: int, kernel: int = 7, stride: int = 1,
                 dropout: float = 0.1):
        super().__init__()
        # Padding of kernel // 2 keeps the length unchanged at stride 1
        # for odd kernel sizes.
        same_pad = kernel // 2

        self.conv1 = nn.Conv1d(
            in_c, out_c, kernel_size=kernel, stride=stride,
            padding=same_pad, bias=False,
        )
        self.bn1 = nn.BatchNorm1d(out_c)
        self.conv2 = nn.Conv1d(
            out_c, out_c, kernel_size=kernel, stride=1,
            padding=same_pad, bias=False,
        )
        self.bn2 = nn.BatchNorm1d(out_c)

        if dropout > 0:
            self.drop = nn.Dropout(dropout)
        else:
            self.drop = nn.Identity()

        shape_preserved = stride == 1 and in_c == out_c
        if shape_preserved:
            self.proj = nn.Identity()
        else:
            # Match the main path's channel count and temporal length.
            self.proj = nn.Sequential(
                nn.Conv1d(in_c, out_c, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm1d(out_c),
            )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the residual unit to ``x`` and return the activated sum."""
        shortcut = self.proj(x)
        out = self.conv1(x)
        out = self.bn1(out)
        out = F.relu(out, inplace=True)
        out = self.drop(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out = out + shortcut
        return F.relu(out, inplace=True)
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
class RhythmResNet1D(nn.Module):
    """1D ResNet classifier: stem, four residual stages, pooled MLP head.

    - Stem: strided convolution (stride 2) -> BN -> ReLU -> MaxPool(2).
    - Stages: four stages of ``blocks_per_stage`` residual blocks each. The
      first stage keeps ``base_channels`` and the temporal resolution; each
      later stage starts with a stride-2 block and doubles the channel
      count, capped at 512.
    - Head: global average pool -> Linear(., 128) -> ReLU -> Dropout ->
      Linear(128, num_classes).

    ``forward`` returns raw logits; apply softmax for probabilities.
    """

    def __init__(
        self,
        num_classes: int = 6,
        class_names: List[str] = RHYTHM_CLASS_NAMES,
        n_channels: int = 2,
        base_channels: int = 64,
        blocks_per_stage: int = 2,
        stem_kernel: int = 15,
        block_kernel: int = 7,
        dropout: float = 0.2,
    ):
        """Build the network.

        Args:
            num_classes: size of the output layer.
            class_names: one label per class; copied, so the caller's list
                (including the module-level default) is never mutated.
            n_channels: number of input channels.
            base_channels: channel width of the stem and first stage.
            blocks_per_stage: residual blocks in each of the four stages.
            stem_kernel: kernel size of the stem convolution.
            block_kernel: kernel size of the residual-block convolutions.
            dropout: dropout rate used in the blocks and in the head.

        Raises:
            ValueError: if ``len(class_names) != num_classes``.
        """
        super().__init__()
        # Explicit check instead of ``assert`` so it still runs under -O.
        if len(class_names) != num_classes:
            raise ValueError(
                f"class_names has {len(class_names)} entries "
                f"but num_classes is {num_classes}"
            )
        self.num_classes = num_classes
        self.class_names = list(class_names)
        self.n_channels = n_channels
        self.base_channels = base_channels
        self.blocks_per_stage = blocks_per_stage

        self.stem = nn.Sequential(
            nn.Conv1d(n_channels, base_channels, kernel_size=stem_kernel,
                      stride=2, padding=stem_kernel // 2, bias=False),
            nn.BatchNorm1d(base_channels),
            nn.ReLU(inplace=True),
            nn.MaxPool1d(2),
        )
        stages = []
        in_c = base_channels
        out_c = base_channels
        for s in range(4):
            for b in range(blocks_per_stage):
                # Downsample once, at the first block of every stage but the first.
                stride = 2 if (b == 0 and s > 0) else 1
                stages.append(_BasicBlock1D(in_c, out_c, kernel=block_kernel,
                                            stride=stride, dropout=dropout))
                in_c = out_c
            out_c = min(out_c * 2, 512)
        self.stages = nn.Sequential(*stages)
        self.pool = nn.AdaptiveAvgPool1d(1)
        # in_c now holds the channel count produced by the last stage.
        self.head = nn.Sequential(
            nn.Linear(in_c, 128),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(128, num_classes),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map a ``(B, n_channels, L)`` batch to ``(B, num_classes)`` logits."""
        h = self.stem(x)
        h = self.stages(h)
        # The adaptive pool leaves a trailing length-1 axis; drop it.
        feat = self.pool(h).squeeze(-1)
        return self.head(feat)

    @classmethod
    def load(cls, path: str | Path, device: str = "cpu") -> "RhythmResNet1D":
        """Rebuild a model from a checkpoint file and return it in eval mode.

        The checkpoint must be a dict with keys ``num_classes``,
        ``class_names``, ``n_channels`` and ``state_dict``. ``base_channels``
        and ``blocks_per_stage`` are optional and default to 64 and 2.
        Kernel sizes and dropout are not read from the checkpoint; the
        constructor defaults are used.
        """
        # SECURITY: weights_only=False lets torch.load unpickle arbitrary
        # objects, which can execute code embedded in the file. Only load
        # checkpoints from a trusted source. NOTE(review): switch to
        # weights_only=True once the checkpoint is confirmed to hold only
        # tensors and plain Python containers.
        ckpt = torch.load(path, map_location=device, weights_only=False)
        model = cls(
            num_classes=ckpt["num_classes"], class_names=ckpt["class_names"],
            n_channels=ckpt["n_channels"],
            base_channels=ckpt.get("base_channels", 64),
            blocks_per_stage=ckpt.get("blocks_per_stage", 2),
        )
        model.load_state_dict(ckpt["state_dict"])
        model.to(device).eval()
        return model
|
requirements.txt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
torch>=2.0
|
| 2 |
+
huggingface_hub>=0.20
|
| 3 |
+
numpy>=1.24
|