# SPDX-License-Identifier: MIT
"""Self-contained RhythmResNet1D for LTAF rhythm classification.

Vendored from rmxjck/TSLM-Arena (src/models/ts_llm/ecg_rhythm_scratch.py) so
the model can be loaded with no external project imports.
"""
from __future__ import annotations

from pathlib import Path
from typing import List

import torch
import torch.nn as nn
import torch.nn.functional as F

RHYTHM_CLASS_NAMES = ["NSR", "AFIB", "SBR", "AB", "SVTA", "B"]


class _BasicBlock1D(nn.Module):
    """Two-conv residual block with optional stride-2 downsample.

    The main path is conv -> BN -> ReLU -> dropout -> conv -> BN; the result
    is added to the shortcut and passed through a final ReLU.

    Args:
        in_c: Number of input channels.
        out_c: Number of output channels.
        kernel: Kernel size of both convolutions. Padding is ``kernel // 2``,
            which keeps the stride-1 length unchanged only for odd kernels;
            an even kernel makes the main path longer than the shortcut.
        stride: Stride of the first convolution (and of the shortcut
            projection, when one is needed).
        dropout: Dropout probability between the two convolutions; values
            ``<= 0`` disable dropout entirely.
    """

    def __init__(
        self,
        in_c: int,
        out_c: int,
        kernel: int = 7,
        stride: int = 1,
        dropout: float = 0.1,
    ):
        super().__init__()
        pad = kernel // 2
        self.conv1 = nn.Conv1d(
            in_c, out_c, kernel_size=kernel, stride=stride, padding=pad, bias=False
        )
        self.bn1 = nn.BatchNorm1d(out_c)
        self.conv2 = nn.Conv1d(
            out_c, out_c, kernel_size=kernel, stride=1, padding=pad, bias=False
        )
        self.bn2 = nn.BatchNorm1d(out_c)
        self.drop = nn.Dropout(dropout) if dropout > 0 else nn.Identity()

        # The shortcut must match the main path in channels and length, so a
        # 1x1 projection is used whenever either of them changes.
        if stride != 1 or in_c != out_c:
            self.proj = nn.Sequential(
                nn.Conv1d(in_c, out_c, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm1d(out_c),
            )
        else:
            self.proj = nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the residual block to ``x`` of shape (batch, in_c, time)."""
        identity = self.proj(x)
        h = F.relu(self.bn1(self.conv1(x)), inplace=True)
        h = self.drop(h)
        h = self.bn2(self.conv2(h))
        return F.relu(h + identity, inplace=True)


class RhythmResNet1D(nn.Module):
    """1D ResNet: stem + 4 residual stages + pooled MLP head.

    The stem (stride-2 conv followed by max-pool 2) reduces the time axis by
    roughly 4x. The first stage keeps that resolution; each of the three
    following stages halves the time axis in its first block. Channel width
    doubles from one stage to the next and is capped at 512.

    Args:
        num_classes: Number of output classes.
        class_names: Human-readable label for every class; its length must
            equal ``num_classes``. A copy is stored on the instance.
        n_channels: Number of input signal channels.
        base_channels: Channel width of the stem and of the first stage.
        blocks_per_stage: Number of residual blocks in each stage.
        stem_kernel: Kernel size of the stem convolution.
        block_kernel: Kernel size used inside the residual blocks (should be
            odd so the residual shapes line up).
        dropout: Dropout probability used in the blocks and in the head.

    Raises:
        ValueError: If ``len(class_names) != num_classes``.
    """

    def __init__(
        self,
        num_classes: int = 6,
        class_names: List[str] = RHYTHM_CLASS_NAMES,
        n_channels: int = 2,
        base_channels: int = 64,
        blocks_per_stage: int = 2,
        stem_kernel: int = 15,
        block_kernel: int = 7,
        dropout: float = 0.2,
    ):
        super().__init__()
        # Explicit check instead of ``assert``: assertions are stripped under
        # ``python -O`` and would let a mislabeled model through silently.
        if len(class_names) != num_classes:
            raise ValueError(
                f"class_names has {len(class_names)} entries "
                f"but num_classes is {num_classes}"
            )

        self.num_classes = num_classes
        # Copy so the shared module-level default list is never aliased.
        self.class_names = list(class_names)
        self.n_channels = n_channels
        self.base_channels = base_channels
        self.blocks_per_stage = blocks_per_stage
        # Kept alongside the other hyper-parameters so that a checkpoint
        # writer can record everything ``load`` needs to rebuild the model.
        self.stem_kernel = stem_kernel
        self.block_kernel = block_kernel
        self.dropout = dropout

        self.stem = nn.Sequential(
            nn.Conv1d(
                n_channels,
                base_channels,
                kernel_size=stem_kernel,
                stride=2,
                padding=stem_kernel // 2,
                bias=False,
            ),
            nn.BatchNorm1d(base_channels),
            nn.ReLU(inplace=True),
            nn.MaxPool1d(2),
        )

        stages = []
        in_c = base_channels
        out_c = base_channels
        for s in range(4):
            for b in range(blocks_per_stage):
                # Only the first block of stages 1..3 downsamples in time.
                stride = 2 if (b == 0 and s > 0) else 1
                stages.append(
                    _BasicBlock1D(
                        in_c,
                        out_c,
                        kernel=block_kernel,
                        stride=stride,
                        dropout=dropout,
                    )
                )
                in_c = out_c
            out_c = min(out_c * 2, 512)
        self.stages = nn.Sequential(*stages)

        self.pool = nn.AdaptiveAvgPool1d(1)
        # ``in_c`` now holds the channel width produced by the last stage.
        self.head = nn.Sequential(
            nn.Linear(in_c, 128),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(128, num_classes),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return unnormalized class scores of shape (batch, num_classes).

        Args:
            x: Input of shape (batch, n_channels, time).
        """
        h = self.stem(x)
        h = self.stages(h)
        feat = self.pool(h).squeeze(-1)  # (batch, channels, 1) -> (batch, channels)
        return self.head(feat)

    @classmethod
    def load(cls, path: str | Path, device: str = "cpu") -> "RhythmResNet1D":
        """Rebuild a model from a checkpoint and return it in eval mode.

        The checkpoint must be a mapping with the keys ``num_classes``,
        ``class_names``, ``n_channels`` and ``state_dict``. The keys
        ``base_channels``, ``blocks_per_stage``, ``stem_kernel``,
        ``block_kernel`` and ``dropout`` are optional and fall back to the
        constructor defaults when absent.

        Args:
            path: Location of the checkpoint, passed straight to ``torch.load``.
            device: Device the weights are mapped to and the model is moved to.

        Returns:
            The restored model, on ``device``, with ``eval()`` applied.

        Raises:
            KeyError: If a required checkpoint key is missing.
            ValueError: If the stored ``class_names`` and ``num_classes``
                disagree.
        """
        # SECURITY: ``weights_only=False`` unpickles arbitrary Python objects,
        # which can execute code. Only load checkpoints from trusted sources.
        ckpt = torch.load(path, map_location=device, weights_only=False)
        model = cls(
            num_classes=ckpt["num_classes"],
            class_names=ckpt["class_names"],
            n_channels=ckpt["n_channels"],
            base_channels=ckpt.get("base_channels", 64),
            blocks_per_stage=ckpt.get("blocks_per_stage", 2),
            # Kernel sizes determine the weight shapes, so they have to be
            # restored for ``load_state_dict`` to accept non-default models.
            stem_kernel=ckpt.get("stem_kernel", 15),
            block_kernel=ckpt.get("block_kernel", 7),
            dropout=ckpt.get("dropout", 0.2),
        )
        model.load_state_dict(ckpt["state_dict"])
        model.to(device).eval()
        return model