# rmxjck's picture
# Initial release
# a888b0c verified
# SPDX-License-Identifier: MIT
"""Self-contained HTF beat classifier for LTAF (N / A / V).
Three parallel streams (time + frequency + history) → fused → MLP head.
Inspired by alberto-rota/PAC-PVC-Beat-Classifier-for-ECGs (HTF ensemble).
Vendored from rmxjck/TSLM-Arena (src/models/ts_llm/ecg_beat_htf.py).
"""
from __future__ import annotations
from pathlib import Path
from typing import List
import torch
import torch.nn as nn
import torch.nn.functional as F
# Beat class names in label-index order (0 = "N", 1 = "A", 2 = "V"); also the
# default `class_names` of EcgBeatHTFClassifier.
# NOTE(review): presumably N = normal, A = atrial premature (PAC) and
# V = ventricular premature (PVC) — confirm against the dataset annotations.
BEAT_CLASS_NAMES = ["N", "A", "V"]
class _ConvBlock(nn.Module):
    """Conv1d → BatchNorm1d → ReLU → optional MaxPool1d → optional Dropout.

    Args:
        in_ch: number of input channels.
        out_ch: number of output channels.
        kernel: convolution kernel size; padding is ``kernel // 2``.
        pool: max-pool factor; values ``<= 1`` disable pooling.
        dropout: dropout probability; values ``<= 0`` disable dropout.
    """

    def __init__(self, in_ch: int, out_ch: int, kernel: int = 7,
                 pool: int = 2, dropout: float = 0.0):
        super().__init__()
        # No conv bias: the BatchNorm that follows has its own shift term.
        self.conv = nn.Conv1d(
            in_ch,
            out_ch,
            kernel_size=kernel,
            padding=kernel // 2,
            bias=False,
        )
        self.bn = nn.BatchNorm1d(out_ch)
        self.act = nn.ReLU(inplace=True)

        # Disabled stages become Identity so forward() needs no branching.
        if pool > 1:
            self.pool = nn.MaxPool1d(pool)
        else:
            self.pool = nn.Identity()

        if dropout > 0:
            self.drop = nn.Dropout(dropout)
        else:
            self.drop = nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the block's stages in order on ``x`` and return the result."""
        h = self.conv(x)
        h = self.bn(h)
        h = self.act(h)
        h = self.pool(h)
        return self.drop(h)
class _CNNTrunk(nn.Module):
    """Stack of ``_ConvBlock`` layers followed by global average pooling.

    The first block maps ``in_channels`` to ``base_channels``; each later
    block doubles the width, capped at 256. Every block uses kernel 7 and
    pool 2. ``out_channels`` holds the width of the returned feature vector.
    """

    def __init__(self, in_channels: int, base_channels: int = 32,
                 n_blocks: int = 5, dropout: float = 0.1):
        super().__init__()
        blocks = []
        prev_width, next_width = in_channels, base_channels
        for _ in range(n_blocks):
            blocks.append(
                _ConvBlock(prev_width, next_width, kernel=7, pool=2,
                           dropout=dropout)
            )
            # Advance: this block's output feeds the next, whose width doubles.
            prev_width, next_width = next_width, min(next_width * 2, 256)
        self.net = nn.Sequential(*blocks)
        self.pool = nn.AdaptiveAvgPool1d(1)
        self.out_channels = prev_width

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the conv stack, average over the last axis and drop that axis."""
        features = self.net(x)
        pooled = self.pool(features)
        return pooled.squeeze(-1)
class EcgBeatHTFClassifier(nn.Module):
    """HTF ensemble beat classifier: time + frequency + history → MLP head.

    Three feature streams are computed per beat and concatenated:

    * time: a 5-block ``_CNNTrunk`` over the raw ECG window;
    * frequency: a 4-block ``_CNNTrunk`` over ``log1p(|rfft(window)|)``;
    * history: a two-layer MLP (64 units each) over the RR-interval history,
      optionally concatenated with the flattened one-hot label history.

    The fused vector is passed through a one-hidden-layer MLP that produces
    ``num_classes`` logits.

    Args:
        num_classes: number of output classes.
        class_names: one name per class, in label-index order (copied).
        n_channels: number of ECG leads (input channels of both CNN trunks).
        window_samples: nominal window length in samples. Stored for
            reference only; it is not enforced on inputs.
        history_k: number of preceding beats described by the history inputs.
        history_use_labels: if True, ``forward`` also consumes the labels of
            the preceding ``history_k`` beats.
        time_base_channels: width of the first block of the time trunk.
        freq_base_channels: width of the first block of the frequency trunk.
        head_hidden: hidden width of the classification head.
        dropout: dropout probability used in the trunks, history net and head.

    Raises:
        ValueError: if ``len(class_names) != num_classes``.
    """

    def __init__(
        self,
        num_classes: int = 3,
        class_names: List[str] = BEAT_CLASS_NAMES,
        n_channels: int = 2,
        window_samples: int = 256,
        history_k: int = 5,
        history_use_labels: bool = True,
        time_base_channels: int = 32,
        freq_base_channels: int = 32,
        head_hidden: int = 128,
        dropout: float = 0.1,
    ):
        super().__init__()
        # Explicit raise instead of `assert`, which is stripped under `python -O`.
        if len(class_names) != num_classes:
            raise ValueError(
                f"class_names has {len(class_names)} entries but "
                f"num_classes={num_classes}"
            )
        self.num_classes = num_classes
        # Copy so the instance never aliases the shared module-level default.
        self.class_names = list(class_names)
        self.n_channels = n_channels
        self.window_samples = window_samples
        self.history_k = history_k
        self.history_use_labels = history_use_labels

        self.time_trunk = _CNNTrunk(n_channels, time_base_channels,
                                    n_blocks=5, dropout=dropout)
        self.freq_trunk = _CNNTrunk(n_channels, freq_base_channels,
                                    n_blocks=4, dropout=dropout)

        # History input: K RR intervals, plus K one-hot label vectors if enabled.
        history_in = history_k
        if history_use_labels:
            history_in += history_k * num_classes
        self.history_net = nn.Sequential(
            nn.Linear(history_in, 64), nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(64, 64), nn.ReLU(inplace=True),
        )

        # 64 is the output width of `history_net` above.
        fused_dim = self.time_trunk.out_channels + self.freq_trunk.out_channels + 64
        self.head = nn.Sequential(
            nn.Linear(fused_dim, head_hidden),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(head_hidden, num_classes),
        )

    def _compute_freq(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``log1p`` of the magnitude of the real FFT of ``x`` (last axis).

        For a last-axis length ``T`` the result has ``T // 2 + 1`` bins.
        """
        spec = torch.fft.rfft(x, dim=-1)
        return torch.log1p(spec.abs())

    def forward(
        self,
        x_time: torch.Tensor,
        rr_history: torch.Tensor,
        label_history: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Classify a batch of beats.

        Args:
            x_time: (B, n_channels, T) — raw ECG window centered on the R-peak;
                (B, 2, 256) with the default configuration.
            rr_history: (B, K) — RR intervals to the preceding K beats, in
                seconds. Set 0.0 for missing (record start).
            label_history: (B, K) integer labels of the preceding K beats
                (0=N, 1=A, 2=V). Use -1 for missing; such entries become
                all-zero one-hot rows. Required iff `history_use_labels=True`.
                At inference time it can be filled with previous predictions
                (autoregressive) or zeros (recall slightly degrades, per the
                original author's note). Zero encodes class index 0, not
                "missing".

        Returns:
            (B, num_classes) logits.

        Raises:
            ValueError: if `history_use_labels=True` and `label_history` is
                None or not 2-D.
        """
        time_feat = self.time_trunk(x_time)
        freq_feat = self.freq_trunk(self._compute_freq(x_time))

        if self.history_use_labels:
            # Explicit raises: `assert` would vanish under `python -O`.
            if label_history is None:
                raise ValueError(
                    "label_history is required when history_use_labels=True"
                )
            if label_history.dim() != 2:
                raise ValueError(
                    "label_history must have shape (B, K), got "
                    f"{tuple(label_history.shape)}"
                )
            batch = label_history.shape[0]
            # F.one_hot only accepts int64 indices; the cast is a no-op for int64.
            labels = label_history.to(torch.long)
            # Mask for real (non-negative) labels; negatives mean "missing".
            valid = (labels >= 0).float().unsqueeze(-1)
            # Clamp so missing entries index safely, then zero them via the mask.
            idx = labels.clamp(min=0)
            one_hot = F.one_hot(idx, num_classes=self.num_classes).float() * valid
            hist_in = torch.cat([rr_history, one_hot.reshape(batch, -1)], dim=-1)
        else:
            hist_in = rr_history

        hist_feat = self.history_net(hist_in)
        fused = torch.cat([time_feat, freq_feat, hist_feat], dim=-1)
        return self.head(fused)

    @classmethod
    def load(cls, path: str | Path, device: str = "cpu") -> "EcgBeatHTFClassifier":
        """Rebuild a model from a checkpoint file and return it in eval mode.

        The checkpoint must be a mapping with the keys ``num_classes``,
        ``class_names``, ``n_channels``, ``window_samples``, ``history_k``,
        ``history_use_labels`` and ``state_dict``. The keys
        ``time_base_channels``, ``freq_base_channels``, ``head_hidden`` and
        ``dropout`` are forwarded to the constructor when present; otherwise
        the constructor defaults are used.

        Args:
            path: checkpoint file path.
            device: device the tensors are mapped to and the model is moved to.

        Returns:
            The restored model, on ``device``, in eval mode.
        """
        # SECURITY: weights_only=False runs the full pickle machinery and can
        # execute arbitrary code. Only load checkpoints from trusted sources.
        # NOTE(review): consider weights_only=True if the checkpoint holds
        # only tensors and plain Python containers.
        ckpt = torch.load(path, map_location=device, weights_only=False)

        # Architecture hyperparameters that change parameter shapes must match
        # the ones used at training time, or load_state_dict() will fail.
        optional_keys = (
            "time_base_channels",
            "freq_base_channels",
            "head_hidden",
            "dropout",
        )
        extra = {key: ckpt[key] for key in optional_keys if key in ckpt}

        model = cls(
            num_classes=ckpt["num_classes"],
            class_names=ckpt["class_names"],
            n_channels=ckpt["n_channels"],
            window_samples=ckpt["window_samples"],
            history_k=ckpt["history_k"],
            history_use_labels=ckpt["history_use_labels"],
            **extra,
        )
        model.load_state_dict(ckpt["state_dict"])
        model.to(device).eval()
        return model