rmxjck's picture
Initial release
464d595 verified
# SPDX-License-Identifier: MIT
"""Self-contained RhythmResNet1D for LTAF rhythm classification.
Vendored from rmxjck/TSLM-Arena (src/models/ts_llm/ecg_rhythm_scratch.py)
so the model can be loaded with no external project imports.
"""
from __future__ import annotations
from pathlib import Path
from typing import List
import torch
import torch.nn as nn
import torch.nn.functional as F
# Default label set for RhythmResNet1D (six entries, matching its default
# num_classes=6). Position i presumably names output score i of the model —
# TODO confirm against the training label encoding.
RHYTHM_CLASS_NAMES = "NSR AFIB SBR AB SVTA B".split()
class _BasicBlock1D(nn.Module):
    """Residual unit for 1D signals: two convolutions plus a shortcut.

    Main path: conv(stride) -> BN -> ReLU -> dropout -> conv -> BN.
    Shortcut: the input unchanged, or a strided 1x1 conv + BN whenever the
    stride or the channel count would otherwise make the two paths disagree.
    The paths are summed and passed through a final ReLU.
    """

    def __init__(self, in_c: int, out_c: int, kernel: int = 7, stride: int = 1,
                 dropout: float = 0.1):
        super().__init__()
        # Half-kernel padding: for odd kernels only ``stride`` changes the
        # length of the time axis. NOTE(review): with an even kernel the main
        # path comes out longer than the shortcut, so the sum would fail.
        half_width = kernel // 2

        # Submodules are registered in a fixed order; changing it would alter
        # state_dict key order and the RNG draws used for weight init.
        self.conv1 = nn.Conv1d(
            in_c, out_c, kernel_size=kernel, stride=stride,
            padding=half_width, bias=False,
        )
        self.bn1 = nn.BatchNorm1d(out_c)
        self.conv2 = nn.Conv1d(
            out_c, out_c, kernel_size=kernel, stride=1,
            padding=half_width, bias=False,
        )
        self.bn2 = nn.BatchNorm1d(out_c)
        # A non-positive rate disables dropout entirely.
        self.drop = nn.Dropout(dropout) if dropout > 0 else nn.Identity()

        reshapes = stride != 1 or in_c != out_c
        if not reshapes:
            self.proj = nn.Identity()
        else:
            # 1x1 projection so the shortcut matches the main path's channel
            # count and strided length.
            self.proj = nn.Sequential(
                nn.Conv1d(in_c, out_c, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm1d(out_c),
            )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the residual unit; the result has ``out_c`` channels.

        ``x`` is assumed to be (batch, in_c, time), the layout Conv1d followed
        by BatchNorm1d expects for batched input.
        """
        skip = self.proj(x)
        y = self.conv1(x)
        y = F.relu(self.bn1(y), inplace=True)
        y = self.drop(y)
        y = self.bn2(self.conv2(y))
        y = y + skip
        return F.relu(y, inplace=True)
class RhythmResNet1D(nn.Module):
"""1D ResNet — stem + 4 stages, each stage halves time and doubles channels."""
def __init__(
self,
num_classes: int = 6,
class_names: List[str] = RHYTHM_CLASS_NAMES,
n_channels: int = 2,
base_channels: int = 64,
blocks_per_stage: int = 2,
stem_kernel: int = 15,
block_kernel: int = 7,
dropout: float = 0.2,
):
super().__init__()
assert len(class_names) == num_classes
self.num_classes = num_classes
self.class_names = list(class_names)
self.n_channels = n_channels
self.base_channels = base_channels
self.blocks_per_stage = blocks_per_stage
self.stem = nn.Sequential(
nn.Conv1d(n_channels, base_channels, kernel_size=stem_kernel,
stride=2, padding=stem_kernel // 2, bias=False),
nn.BatchNorm1d(base_channels),
nn.ReLU(inplace=True),
nn.MaxPool1d(2),
)
stages = []
in_c = base_channels
out_c = base_channels
for s in range(4):
for b in range(blocks_per_stage):
stride = 2 if (b == 0 and s > 0) else 1
stages.append(_BasicBlock1D(in_c, out_c, kernel=block_kernel,
stride=stride, dropout=dropout))
in_c = out_c
out_c = min(out_c * 2, 512)
self.stages = nn.Sequential(*stages)
self.pool = nn.AdaptiveAvgPool1d(1)
self.head = nn.Sequential(
nn.Linear(in_c, 128),
nn.ReLU(inplace=True),
nn.Dropout(dropout),
nn.Linear(128, num_classes),
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
h = self.stem(x)
h = self.stages(h)
feat = self.pool(h).squeeze(-1)
return self.head(feat)
@classmethod
def load(cls, path: str | Path, device: str = "cpu") -> "RhythmResNet1D":
ckpt = torch.load(path, map_location=device, weights_only=False)
model = cls(
num_classes=ckpt["num_classes"], class_names=ckpt["class_names"],
n_channels=ckpt["n_channels"],
base_channels=ckpt.get("base_channels", 64),
blocks_per_stage=ckpt.get("blocks_per_stage", 2),
)
model.load_state_dict(ckpt["state_dict"])
model.to(device).eval()
return model