"""ECG masked autoencoder: patch embedding, visibility-restricted encoder, full decoder.""" from __future__ import annotations from typing import Any import torch import torch.nn as nn from torch import Tensor from mae.config import MAEConfig from mae.decoder import DecoderBlock from mae.encoder import EncoderAttentionBlock, build_encoder_attn_bias from mae.losses import masked_reconstruction_loss from mae.stdm import sample_stdm_masks class ECGDataMAE(nn.Module): """ Reconstruction-only MAE with STDM masking and visibility-restricted encoder. """ def __init__(self, config: MAEConfig): super().__init__() self.cfg = config C, N, P = config.num_leads, config.num_patches, config.patch_size L = C * N d = config.d_model self.patch_embed = nn.Linear(P, d) self.pos_embed = nn.Parameter(torch.zeros(1, L, d)) nn.init.trunc_normal_(self.pos_embed, std=0.02) self.encoder_blocks = nn.ModuleList( [ EncoderAttentionBlock(d, config.n_heads, config.mlp_ratio, config.dropout) for _ in range(config.encoder_layers) ] ) self.decoder_embed = nn.Linear(d, d) self.decoder_pos = nn.Parameter(torch.zeros(1, L, d)) nn.init.trunc_normal_(self.decoder_pos, std=0.02) self.decoder_mask_token = nn.Parameter(torch.zeros(1, 1, d)) nn.init.trunc_normal_(self.decoder_mask_token, std=0.02) self.decoder_blocks = nn.ModuleList( [ DecoderBlock(d, config.n_heads, config.mlp_ratio, config.dropout) for _ in range(config.decoder_layers) ] ) self.decoder_norm = nn.LayerNorm(d) self.head = nn.Linear(d, P) self.apply(self._init_weights) def _init_weights(self, m: nn.Module) -> None: if isinstance(m, nn.Linear): nn.init.trunc_normal_(m.weight, std=0.02) if m.bias is not None: nn.init.zeros_(m.bias) elif isinstance(m, nn.LayerNorm): nn.init.ones_(m.weight) nn.init.zeros_(m.bias) def patchify(self, x: torch.Tensor) -> torch.Tensor: """(B, C, T) -> (B, C, N, P)""" B, C, T = x.shape N, P = self.cfg.num_patches, self.cfg.patch_size if C != self.cfg.num_leads or T != N * P: raise ValueError(f"expected (B,{self.cfg.num_leads},{N*P}), got {x.shape}") return x.reshape(B, C, N, P) @staticmethod def _validate_external_masks( V: Tensor, M: Tensor, D: Tensor, *, B: int, C: int, N: int ) -> None: for name, t in ("V", V), ("M", M), ("D", D): if t.shape != (B, C, N): raise ValueError(f"{name} must have shape (B, C, N)={(B, C, N)}, got {tuple(t.shape)}") if t.dtype != torch.bool: raise ValueError(f"{name} must be bool, got {t.dtype}") if (V & M).any() or (V & D).any() or (M & D).any(): raise ValueError("V, M, D must be mutually exclusive (no cell in more than one set).") if not (V | M | D).all(): raise ValueError("V, M, D must partition the grid: each (b, c, n) in exactly one of V, M, D.") def forward( self, x: torch.Tensor, *, V: Tensor | None = None, M: Tensor | None = None, D: Tensor | None = None, generator: torch.Generator | None = None, return_loss: bool = True, ) -> dict[str, Any]: """ If ``V``, ``M``, and ``D`` are all omitted, sample STDM masks (training default). If provided, all three must be given: boolean ``(B, C, N)`` masks partitioning the patch grid. """ B, C, T = x.shape N = self.cfg.num_patches patches = self.patchify(x) any_mask = V is not None or M is not None or D is not None if any_mask: if V is None or M is None or D is None: raise ValueError("Provide all three of V, M, D or omit all to sample STDM.") self._validate_external_masks(V, M, D, B=B, C=C, N=N) else: V, M, D = sample_stdm_masks( B, C, N, p_time=self.cfg.p_time, p_lead=self.cfg.p_lead, k_visible=self.cfg.num_visible_leads, device=x.device, generator=generator, ) L = C * N P = self.cfg.patch_size flat = patches.reshape(B, L, P) tok = self.patch_embed(flat) + self.pos_embed v_flat = V.reshape(B, L) tok = tok * v_flat.unsqueeze(-1).to(tok.dtype) attn_bias = build_encoder_attn_bias(v_flat) h = tok for blk in self.encoder_blocks: h = blk(h, v_flat, attn_bias) dec_in = self.decoder_embed(h) dec_in = dec_in * v_flat.unsqueeze(-1).to(dec_in.dtype) dec_in = dec_in + self.decoder_mask_token * (~v_flat).unsqueeze(-1).to(dec_in.dtype) dec_in = dec_in + self.decoder_pos z = dec_in for blk in self.decoder_blocks: z = blk(z) z = self.decoder_norm(z) pred_flat = self.head(z) pred = pred_flat.reshape(B, C, N, P) out: dict[str, Any] = {"pred": pred, "V": V, "M": M, "D": D, "patches": patches} if return_loss: out["loss"] = masked_reconstruction_loss(pred, patches, M) return out