#!/usr/bin/env python3
# SPDX-License-Identifier: MIT
"""Inference example for the LTAF ECG rhythm classifier.
Two modes:
- Single-window (``predict_single``): pass one already z-scored (2, 1280) NumPy array, i.e. 10 s @ 128 Hz.
- TTA-7 (``predict_tta``, recommended, +4 pp F1): pass a longer (2, L) signal slice;
the function samples 7 random 10 s windows from it, z-scores each one, and soft-votes.
Usage:
.venv/bin/python inference.py
"""
from __future__ import annotations
from pathlib import Path
from typing import Tuple
import numpy as np
import torch
import torch.nn.functional as F
from huggingface_hub import hf_hub_download
from model import RHYTHM_CLASS_NAMES, RhythmResNet1D
# Every model input is one fixed-length window: 10 seconds sampled at 128 Hz.
WINDOW_SECONDS = 10
SOURCE_HZ = 128
WINDOW_SAMPLES = SOURCE_HZ * WINDOW_SECONDS  # samples per window (1280)
def load_model(
    device: str = "cpu",
    *,
    repo_id: str = "rmxjck/ltaf-ecg-rhythm-classifier",
    filename: str = "best_classifier.pt",
    revision: str | None = None,
) -> RhythmResNet1D:
    """Fetch the classifier checkpoint from the Hugging Face Hub and load it.

    Args:
        device: Device string forwarded to ``RhythmResNet1D.load``.
        repo_id: Hub repository holding the checkpoint. Defaults to the
            published LTAF rhythm classifier.
        filename: Checkpoint file name inside ``repo_id``.
        revision: Optional branch, tag or commit to pin. ``None`` keeps
            ``hf_hub_download``'s default revision.

    Returns:
        The model returned by ``RhythmResNet1D.load`` for the downloaded file.
    """
    ckpt_path = hf_hub_download(repo_id, filename, revision=revision)
    return RhythmResNet1D.load(ckpt_path, device=device)
def zscore(window: np.ndarray) -> np.ndarray:
    """Standardise each channel of a (C, L) array along its last axis.

    Every channel is centred on its own mean and divided by its own standard
    deviation plus 1e-6, so a constant channel does not cause a division by
    zero. The result is returned as float32.
    """
    centred = window - window.mean(axis=-1, keepdims=True)
    scale = window.std(axis=-1, keepdims=True) + 1e-6
    return (centred / scale).astype(np.float32, copy=False)
def predict_single(
    model: RhythmResNet1D,
    window: np.ndarray,
    device: str = "cpu",
) -> Tuple[str, float]:
    """Classify one pre-normalised 10 s window.

    Args:
        model: Loaded classifier; must expose ``class_names``.
        window: Array of shape (2, WINDOW_SAMPLES). It is fed to the model
            as-is, so the caller is responsible for z-scoring it first
            (e.g. with ``zscore``).
        device: Torch device the input tensor is moved to.

    Returns:
        ``(class_name, probability)`` for the arg-max class of the softmax output.

    Raises:
        ValueError: if ``window`` does not have shape (2, WINDOW_SAMPLES).
    """
    expected_shape = (2, WINDOW_SAMPLES)
    if window.shape != expected_shape:
        raise ValueError(f"Expected (2, {WINDOW_SAMPLES}), got {window.shape}")
    # Add a leading batch axis of size 1: (2, L) -> (1, 2, L).
    batch = torch.from_numpy(window).float()[None].to(device)
    with torch.no_grad():
        logits = model(batch)
        class_probs = F.softmax(logits, dim=-1)[0]
    best = int(torch.argmax(class_probs).item())
    return model.class_names[best], float(class_probs[best].item())
def predict_tta(
    model: RhythmResNet1D,
    long_signal: np.ndarray,
    n_views: int = 7,
    device: str = "cpu",
    seed: int = 42,
) -> Tuple[str, float, np.ndarray]:
    """Soft-vote prediction over several random 10 s windows of a longer signal.

    Draws ``n_views`` window start offsets uniformly at random (reproducible
    via ``seed``), z-scores each window independently with ``zscore``, runs
    each one through ``model`` and averages the softmax probabilities.
    Windows are drawn independently, so they may overlap or repeat.

    Args:
        model: Loaded classifier exposing ``num_classes`` and ``class_names``.
        long_signal: Array of shape (2, L) with L >= WINDOW_SAMPLES. It does
            not need to be pre-normalised, since every window is z-scored here.
        n_views: Number of windows to average; must be >= 1.
        device: Torch device the windows and the vote accumulator live on.
        seed: Seed for the window-offset RNG.

    Returns:
        ``(class_name, prob, full_probs)``. ``full_probs`` is the averaged
        probability vector of shape (num_classes,) as a NumPy array.

    Raises:
        ValueError: if the input is not 2-D, does not have exactly 2 channels,
            is shorter than one window, or if ``n_views < 1``.
    """
    if long_signal.ndim != 2:
        raise ValueError(
            f"Expected a 2-D (channels, samples) signal, got shape {long_signal.shape}"
        )
    n_ch, n_samples = long_signal.shape
    if n_ch != 2:
        raise ValueError(f"Expected 2-channel signal, got {n_ch}")
    if n_samples < WINDOW_SAMPLES:
        raise ValueError(f"Need at least {WINDOW_SAMPLES} samples, got {n_samples}")
    if n_views < 1:
        # Without this guard, n_views=0 divides by zero and returns NaN probabilities.
        raise ValueError(f"n_views must be >= 1, got {n_views}")
    rng = np.random.default_rng(seed)
    # `high` is exclusive, so the +1 lets the last window end exactly at the signal end.
    starts = rng.integers(0, n_samples - WINDOW_SAMPLES + 1, size=n_views)
    accum = torch.zeros(model.num_classes, device=device)
    # Enter no_grad once for all views rather than once per view.
    with torch.no_grad():
        for s in starts:
            window = zscore(long_signal[:, s:s + WINDOW_SAMPLES])
            x = torch.from_numpy(window).float().unsqueeze(0).to(device)
            accum += F.softmax(model(x), dim=-1)[0]
    probs_avg = accum / n_views
    idx = int(probs_avg.argmax().item())
    return model.class_names[idx], float(probs_avg[idx].item()), probs_avg.cpu().numpy()
def demo():
    """Smoke-test the inference helpers on random (meaningless) input.

    Loads the checkpoint onto CUDA when available (otherwise CPU) and prints
    a short model summary. It then runs ``predict_single`` on one random
    window and ``predict_tta`` on a random 30 s signal. Predictions on noise
    are garbage; this only exercises the pipeline end to end.
    """
    print("Loading model from HF...")
    device = "cpu"
    if torch.cuda.is_available():
        device = "cuda"
    model = load_model(device)
    print(f"Loaded {type(model).__name__} on {device}")
    print(f"Classes: {model.class_names}")
    n_params = sum(param.numel() for param in model.parameters())
    print(f"Params: {n_params:,}")

    # Single window: z-score the noise first, because predict_single does not normalise.
    print("\n--- single-window demo (random input) ---")
    noise_window = zscore(np.random.randn(2, WINDOW_SAMPLES).astype(np.float32))
    label, confidence = predict_single(model, noise_window, device=device)
    print(f"prediction: {label} ({confidence:.1%})")

    # TTA: 30 s of raw noise; predict_tta z-scores each sampled window itself.
    print("\n--- TTA-7 demo (random 30 s input) ---")
    noise_signal = np.random.randn(2, 30 * SOURCE_HZ).astype(np.float32)
    label, confidence, class_probs = predict_tta(
        model, noise_signal, n_views=7, device=device
    )
    print(f"prediction: {label} ({confidence:.1%})")
    rounded = {
        name: round(p, 3) for name, p in zip(model.class_names, class_probs.tolist())
    }
    print(f"all class probs: {rounded}")
if __name__ == "__main__":
    # Running the file directly executes the demo; importing it does not.
    demo()