"""Visibility-restricted encoder attention (CoRe-ECG reconstruction encoder)."""

from __future__ import annotations

import math

import torch
import torch.nn as nn
import torch.nn.functional as F

from mae.mlp import MLP

# Additive penalty applied to disallowed (query, key) pairs. It is finite, so a
# fully penalised row still yields a finite softmax, and it is representable
# in float16.
_MASK_VALUE = -1e4


def build_encoder_attn_bias(v: torch.Tensor) -> torch.Tensor:
    """Build the additive attention bias for a visibility mask.

    Query ``i`` may attend to key ``j`` when both tokens are visible. A
    non-visible query is only allowed to attend to itself (an identity row),
    so every row keeps at least one un-penalised entry.

    Args:
        v: Visibility mask of shape ``(B, L)``. Boolean masks are used as-is;
            any other dtype is interpreted as "non-zero means visible".

    Returns:
        A float32 tensor of shape ``(B, L, L)`` on ``v.device`` holding ``0.0``
        for allowed pairs and ``-1e4`` for every other pair.

    Raises:
        ValueError: If ``v`` is not two-dimensional.
    """
    if v.dim() != 2:
        raise ValueError(
            f"expected a (B, L) visibility mask, got shape {tuple(v.shape)}"
        )

    # Bitwise `&` / `~` below are only logical operations on bool tensors.
    visible = v.to(torch.bool)
    seq_len = visible.shape[1]

    # (B, L, L): True where the query AND the key are both visible.
    both_visible = visible.unsqueeze(2) & visible.unsqueeze(1)

    # (B, L, L): True on the diagonal of rows whose query is NOT visible.
    diagonal = torch.eye(seq_len, device=visible.device).bool()
    hidden_self = (~visible).unsqueeze(-1) & diagonal

    allowed = both_visible | hidden_self

    bias = torch.full(
        allowed.shape, _MASK_VALUE, device=visible.device, dtype=torch.float32
    )
    bias.masked_fill_(allowed, 0.0)
    return bias


class EncoderAttentionBlock(nn.Module):
    """Pre-norm residual block with visibility-gated multi-head self-attention.

    The block applies ``LayerNorm -> attention -> residual`` followed by
    ``LayerNorm -> MLP -> residual`` and finally multiplies the result by the
    visibility mask, so non-visible positions come out as zeros.
    """

    def __init__(self, dim: int, n_heads: int, mlp_ratio: float, dropout: float):
        """Create the block's layers.

        Args:
            dim: Feature width of the tokens.
            n_heads: Number of attention heads; must divide ``dim``.
            mlp_ratio: Multiplier for the MLP hidden width; the hidden size
                passed to ``MLP`` is ``int(dim * mlp_ratio)``.
            dropout: Probability used for attention dropout and forwarded as
                the third argument of ``MLP``.

        Raises:
            ValueError: If ``dim`` is not divisible by ``n_heads``.
        """
        super().__init__()
        if dim % n_heads != 0:
            raise ValueError("d_model must be divisible by n_heads")
        self.n_heads = n_heads
        self.head_dim = dim // n_heads
        self.scale = self.head_dim**-0.5
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        # Single projection producing queries, keys and values at once.
        self.qkv = nn.Linear(dim, 3 * dim)
        self.proj = nn.Linear(dim, dim)
        self.attn_drop = nn.Dropout(dropout)
        self.mlp = MLP(dim, int(dim * mlp_ratio), dropout)

    def forward(
        self, x: torch.Tensor, v: torch.Tensor, attn_bias: torch.Tensor
    ) -> torch.Tensor:
        """Run the block.

        Args:
            x: Token features of shape ``(B, L, D)``.
            v: Visibility mask of shape ``(B, L)``.
            attn_bias: Additive attention bias that broadcasts against
                ``(B, 1, L, L)`` after a head axis is inserted at dim 1,
                e.g. the ``(B, L, L)`` output of ``build_encoder_attn_bias``.

        Returns:
            Tensor shaped like ``x``; rows where ``v`` is zero/False are zero.
        """
        x_norm = self.norm1(x)
        out = self._visibility_attn(x_norm, v, attn_bias)
        x = x + out
        x = x + self.mlp(self.norm2(x))
        # Residual paths and biases leak values into non-visible positions;
        # zero them so they carry nothing downstream.
        x = x * v.unsqueeze(-1).to(x.dtype)
        return x

    def _visibility_attn(
        self, x: torch.Tensor, v: torch.Tensor, attn_bias: torch.Tensor
    ) -> torch.Tensor:
        """Multi-head self-attention restricted by ``attn_bias`` and ``v``.

        Args:
            x: Normalised token features of shape ``(B, L, D)``.
            v: Visibility mask of shape ``(B, L)``.
            attn_bias: Additive bias added to the scaled scores of every head.

        Returns:
            Output-projected attention result of shape ``(B, L, D)``.
        """
        B, L, D = x.shape
        H, Dh = self.n_heads, self.head_dim

        # (B, L, 3*D) -> (3, B, H, L, Dh)
        qkv = self.qkv(x).reshape(B, L, 3, H, Dh).permute(2, 0, 3, 1, 4)
        q, k, val = qkv[0], qkv[1], qkv[2]

        scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale
        scores = scores + attn_bias.unsqueeze(1)  # broadcast over heads
        attn = F.softmax(scores, dim=-1)

        # Zero the attention rows of non-visible queries: (B, L) -> (B, 1, L, 1).
        query_visible = v.unsqueeze(1).unsqueeze(-1).to(attn.dtype)
        attn = attn * query_visible

        # Renormalise rows that still carry mass; all-zero rows stay zero.
        attn_sum = attn.sum(dim=-1, keepdim=True)
        attn = torch.where(attn_sum > 0, attn / attn_sum.clamp_min(1e-8), attn)
        attn = self.attn_drop(attn)

        # A float32 bias can promote `attn` to a wider dtype than `val` (for
        # example when the module runs in half precision); matmul requires
        # matching operand dtypes, so cast back before the weighted sum.
        out = torch.matmul(attn.to(val.dtype), val)
        out = out.transpose(1, 2).reshape(B, L, D)
        return self.proj(out)