File size: 4,181 Bytes
464d595
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
#!/usr/bin/env python3
# SPDX-License-Identifier: MIT
"""Inference example for the LTAF ECG rhythm classifier.

Two modes:
- Single-window: pass a (2, 1280) z-scored 10 s @ 128 Hz NumPy array to
  ``predict_single`` (it adds the batch axis, so the model sees (1, 2, 1280)).
- TTA-7 (recommended, +4 pp F1): pass a longer signal slice and the
  function will pull 7 random 10 s windows from it and soft-vote.

Usage:
    .venv/bin/python inference.py
"""

from __future__ import annotations

from pathlib import Path
from typing import Tuple

import numpy as np
import torch
import torch.nn.functional as F
from huggingface_hub import hf_hub_download

from model import RHYTHM_CLASS_NAMES, RhythmResNet1D


# Window geometry shared by both prediction modes: 10 s of signal at 128 Hz.
SOURCE_HZ = 128
WINDOW_SECONDS = 10
WINDOW_SAMPLES = SOURCE_HZ * WINDOW_SECONDS  # 1280 samples per window


def load_model(device: str = "cpu") -> RhythmResNet1D:
    """Fetch the published checkpoint from the Hugging Face Hub and load it.

    ``device`` is forwarded unchanged to ``RhythmResNet1D.load``; the value
    returned by that call is returned as-is.
    """
    repo_id = "rmxjck/ltaf-ecg-rhythm-classifier"
    checkpoint_name = "best_classifier.pt"
    # hf_hub_download hands back a local filesystem path to the checkpoint.
    local_checkpoint = hf_hub_download(repo_id, checkpoint_name)
    return RhythmResNet1D.load(local_checkpoint, device=device)


def zscore(window: np.ndarray) -> np.ndarray:
    """Standardise a (C, L) array channel by channel.

    Each row has its own mean subtracted and is divided by its own standard
    deviation, both taken along the last axis. The result is returned as
    float32.
    """
    mu = window.mean(axis=-1, keepdims=True)
    sigma = window.std(axis=-1, keepdims=True)
    centred = window - mu
    # The small offset keeps a constant channel (sigma == 0) from dividing by zero.
    scaled = centred / (sigma + 1e-6)
    return scaled.astype(np.float32, copy=False)


def predict_single(
    model: RhythmResNet1D,
    window: np.ndarray,
    device: str = "cpu",
) -> Tuple[str, float]:
    """Classify one (2, 1280) window that has already been z-scored.

    The window is given a leading batch axis, moved to ``device`` and passed
    through ``model``; a softmax over the last axis of the output gives the
    class probabilities.

    Returns ``(class_name, prob)`` for the most probable class, where the name
    is looked up in ``model.class_names``.

    Raises ``ValueError`` if ``window.shape`` is not ``(2, WINDOW_SAMPLES)``.
    """
    expected_shape = (2, WINDOW_SAMPLES)
    if window.shape != expected_shape:
        raise ValueError(f"Expected (2, {WINDOW_SAMPLES}), got {window.shape}")

    # (2, L) -> (1, 2, L): the model is called with a batch of one.
    batch = torch.from_numpy(window).float().unsqueeze(0).to(device)
    with torch.no_grad():
        scores = model(batch)
        class_probs = F.softmax(scores, dim=-1)[0]

    best = int(class_probs.argmax().item())
    confidence = float(class_probs[best].item())
    return model.class_names[best], confidence


def predict_tta(
    model: RhythmResNet1D,
    long_signal: np.ndarray,
    n_views: int = 7,
    device: str = "cpu",
    seed: int = 42,
) -> Tuple[str, float, np.ndarray]:
    """TTA-soft-voting prediction over a longer (2, L) signal.

    Samples ``n_views`` random 10 s windows from ``long_signal`` (L >= 1280),
    z-scores each independently, runs them through the model, and averages
    the softmax probabilities.

    Window start offsets are drawn independently from a generator seeded with
    ``seed``, so the same inputs always give the same windows. Windows may
    overlap or repeat; when L == 1280 every view is the same window.

    Returns (class_name, prob, full_probs) where full_probs is a NumPy array
    of shape (model.num_classes,) holding the averaged probabilities.

    Raises ``ValueError`` if ``long_signal`` is not 2-D, does not have exactly
    2 channels, is shorter than one window, or if ``n_views`` is less than 1.
    """
    if long_signal.ndim != 2:
        raise ValueError(
            f"Expected a 2-D (2, L) signal, got shape {long_signal.shape}"
        )
    n_ch, n_samples = long_signal.shape
    if n_ch != 2:
        raise ValueError(f"Expected 2-channel signal, got {n_ch}")
    if n_samples < WINDOW_SAMPLES:
        raise ValueError(f"Need at least {WINDOW_SAMPLES} samples, got {n_samples}")
    if n_views < 1:
        # With zero views the average below would be 0 / 0, i.e. NaN
        # probabilities paired with an arbitrary class label.
        raise ValueError(f"n_views must be at least 1, got {n_views}")

    rng = np.random.default_rng(seed)
    # The upper bound is exclusive, so the last valid start is n_samples - WINDOW_SAMPLES.
    starts = rng.integers(0, n_samples - WINDOW_SAMPLES + 1, size=n_views)
    accum = torch.zeros(model.num_classes, device=device)
    with torch.no_grad():
        for s in starts:
            window = zscore(long_signal[:, s:s + WINDOW_SAMPLES])
            x = torch.from_numpy(window).float().unsqueeze(0).to(device)
            accum += F.softmax(model(x), dim=-1)[0]
    probs_avg = accum / n_views
    idx = int(probs_avg.argmax().item())
    return model.class_names[idx], float(probs_avg[idx].item()), probs_avg.cpu().numpy()


def demo():
    """Load the model and run both prediction modes on random noise.

    The inputs are synthetic, so the printed predictions are meaningless;
    the point is to exercise ``predict_single`` and ``predict_tta`` end to end.
    """
    print("Loading model from HF...")
    if torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"
    model = load_model(device)
    print(f"Loaded {model.__class__.__name__} on {device}")
    print(f"Classes: {model.class_names}")
    param_count = sum(p.numel() for p in model.parameters())
    print(f"Params: {param_count:,}")

    # One window of Gaussian noise, z-scored as predict_single expects.
    print("\n--- single-window demo (random input) ---")
    noise_window = np.random.randn(2, WINDOW_SAMPLES).astype(np.float32)
    label, confidence = predict_single(model, zscore(noise_window), device=device)
    print(f"prediction: {label} ({confidence:.1%})")

    # Thirty seconds of noise; predict_tta z-scores each sampled window itself.
    print("\n--- TTA-7 demo (random 30 s input) ---")
    noise_long = np.random.randn(2, 30 * SOURCE_HZ).astype(np.float32)
    label, confidence, class_probs = predict_tta(
        model, noise_long, n_views=7, device=device
    )
    print(f"prediction: {label} ({confidence:.1%})")
    rounded = {
        name: round(p, 3)
        for name, p in zip(model.class_names, class_probs.tolist())
    }
    print(f"all class probs: {rounded}")


# Run the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    demo()