| |
| |
| """Inference example for the LTAF ECG rhythm classifier. |
| |
| Two modes: |
| - Single-window: pass one (2, 1280) z-scored 10 s @ 128 Hz array. |
| - TTA-7 (test-time augmentation; recommended, +4 pp F1): pass a longer |
| (2, L) signal and the function will pull 7 random 10 s windows from it, |
| z-score each one, and soft-vote. |
| |
| Usage: |
| .venv/bin/python inference.py |
| """ |
|
|
| from __future__ import annotations |
|
|
| from pathlib import Path |
| from typing import Tuple |
|
|
| import numpy as np |
| import torch |
| import torch.nn.functional as F |
| from huggingface_hub import hf_hub_download |
|
|
| from model import RHYTHM_CLASS_NAMES, RhythmResNet1D |
|
|
|
|
# Input geometry expected by the classifier: 10 s windows sampled at 128 Hz.
SOURCE_HZ = 128
WINDOW_SECONDS = 10
WINDOW_SAMPLES = SOURCE_HZ * WINDOW_SECONDS  # samples per window (1280)
|
|
|
|
def load_model(
    device: str = "cpu",
    repo_id: str = "rmxjck/ltaf-ecg-rhythm-classifier",
    filename: str = "best_classifier.pt",
) -> RhythmResNet1D:
    """Download the classifier checkpoint from the Hugging Face Hub and load it.

    Args:
        device: torch device string forwarded to ``RhythmResNet1D.load``.
        repo_id: Hub repository that holds the checkpoint.
        filename: checkpoint file name inside ``repo_id``.

    Returns:
        The loaded model, switched to eval mode.
    """
    ckpt_path = hf_hub_download(repo_id=repo_id, filename=filename)
    model = RhythmResNet1D.load(ckpt_path, device=device)
    # Force inference behaviour for any normalization/dropout layers. This is a
    # no-op if ``load`` already did it (not visible from this file).
    model.eval()
    return model
|
|
|
|
def zscore(window: np.ndarray) -> np.ndarray:
    """Standardize each channel of a (C, L) array along its last axis.

    Every row is shifted to zero mean and divided by its standard deviation
    plus 1e-6, so a perfectly flat channel maps to zeros instead of dividing
    by zero. The result is returned as float32.
    """
    centered = window - window.mean(axis=-1, keepdims=True)
    scale = window.std(axis=-1, keepdims=True) + 1e-6
    normalized = centered / scale
    return normalized.astype(np.float32, copy=False)
|
|
|
|
def predict_single(
    model: RhythmResNet1D,
    window: np.ndarray,
    device: str = "cpu",
) -> Tuple[str, float]:
    """Classify one already-normalized 10 s window.

    Args:
        model: classifier whose output for a (1, 2, WINDOW_SAMPLES) float
            tensor is softmaxed here into class probabilities.
        window: array of shape exactly (2, WINDOW_SAMPLES). It is NOT
            normalized here; the caller is expected to z-score it first.
        device: torch device string the input tensor is moved to.

    Returns:
        (class_name, probability) of the highest-scoring class.

    Raises:
        ValueError: if ``window`` does not have shape (2, WINDOW_SAMPLES).
    """
    expected_shape = (2, WINDOW_SAMPLES)
    if window.shape != expected_shape:
        raise ValueError(f"Expected (2, {WINDOW_SAMPLES}), got {window.shape}")

    batch = torch.from_numpy(window).float()[None].to(device)  # add batch dim
    with torch.no_grad():
        logits = model(batch)
    class_probs = F.softmax(logits, dim=-1)[0]

    best = int(torch.argmax(class_probs).item())
    return model.class_names[best], float(class_probs[best].item())
|
|
|
|
def predict_tta(
    model: RhythmResNet1D,
    long_signal: np.ndarray,
    n_views: int = 7,
    device: str = "cpu",
    seed: int | None = 42,
) -> Tuple[str, float, np.ndarray]:
    """Test-time-augmentation (TTA) soft vote over a longer (2, L) signal.

    Draws ``n_views`` random window start offsets (uniformly, with
    replacement) from ``long_signal`` (L >= WINDOW_SAMPLES). It z-scores each
    WINDOW_SAMPLES-long window independently, classifies all windows in one
    batched forward pass, and averages the softmax probabilities.

    Args:
        model: classifier whose outputs are softmaxed into class probabilities.
            It should be in eval mode; only then is batching the views
            equivalent to classifying them one at a time.
        long_signal: raw (un-normalized) array of shape (2, L).
        n_views: number of random windows to average; must be >= 1.
        device: torch device string the batch is moved to.
        seed: seed for the window-offset RNG. The default of 42 makes results
            reproducible; ``None`` draws fresh OS entropy.

    Returns:
        (class_name, prob, full_probs), where ``full_probs`` is a 1-D numpy
        array of averaged probabilities with one entry per class.

    Raises:
        ValueError: if the signal is not 2-D, not 2-channel, or shorter than
            one window, or if ``n_views < 1``.
    """
    if long_signal.ndim != 2:
        raise ValueError(
            f"Expected a 2-D (channels, samples) signal, got shape {long_signal.shape}"
        )
    n_ch, n_samples = long_signal.shape
    if n_ch != 2:
        raise ValueError(f"Expected 2-channel signal, got {n_ch}")
    if n_samples < WINDOW_SAMPLES:
        raise ValueError(f"Need at least {WINDOW_SAMPLES} samples, got {n_samples}")
    if n_views < 1:
        # Averaging zero views would divide by zero and yield NaN probabilities.
        raise ValueError(f"n_views must be >= 1, got {n_views}")

    rng = np.random.default_rng(seed)
    # Upper bound is exclusive, so +1 allows the window flush with the end.
    starts = rng.integers(0, n_samples - WINDOW_SAMPLES + 1, size=n_views)
    views = np.stack(
        [zscore(long_signal[:, s:s + WINDOW_SAMPLES]) for s in starts]
    )  # (n_views, 2, WINDOW_SAMPLES)

    x = torch.from_numpy(views).float().to(device)
    with torch.no_grad():
        probs = F.softmax(model(x), dim=-1)  # (n_views, num_classes)
    probs_avg = probs.mean(dim=0)

    idx = int(probs_avg.argmax().item())
    return model.class_names[idx], float(probs_avg[idx].item()), probs_avg.cpu().numpy()
|
|
|
|
def demo():
    """Run both inference modes on synthetic Gaussian noise.

    Downloads the checkpoint, prints a short model summary, then shows the
    single-window and TTA-7 call patterns. Inputs are random, so the printed
    predictions only illustrate the API, not model quality.
    """
    print("Loading model from HF...")
    device = "cpu"
    if torch.cuda.is_available():
        device = "cuda"
    model = load_model(device)
    summary = (
        f"Loaded {model.__class__.__name__} on {device}",
        f"Classes: {model.class_names}",
        f"Params: {sum(p.numel() for p in model.parameters()):,}",
    )
    print(*summary, sep="\n")

    # Single-window mode: the caller z-scores the window before predicting.
    print("\n--- single-window demo (random input) ---")
    raw_window = np.random.randn(2, WINDOW_SAMPLES).astype(np.float32)
    cls, prob = predict_single(model, zscore(raw_window), device=device)
    print(f"prediction: {cls} ({prob:.1%})")

    # TTA mode: predict_tta z-scores each sampled window itself, so the raw
    # 30 s signal is passed in unnormalized.
    print("\n--- TTA-7 demo (random 30 s input) ---")
    fake_long = np.random.randn(2, 30 * SOURCE_HZ).astype(np.float32)
    cls, prob, full = predict_tta(model, fake_long, n_views=7, device=device)
    print(f"prediction: {cls} ({prob:.1%})")
    print(f"all class probs: {dict(zip(model.class_names, [round(p, 3) for p in full.tolist()]))}")
|
|
|
|
if __name__ == "__main__":
    # Run the demo only when this file is executed as a script, not on import.
    demo()
|
|