| """Visibility-restricted encoder attention (CoRe-ECG reconstruction encoder).""" | |
| from __future__ import annotations | |
| import math | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from mae.mlp import MLP | |


def build_encoder_attn_bias(v: torch.Tensor) -> torch.Tensor:
    """
    Additive attention bias (B, L, L): visible queries attend to visible keys;
    non-visible queries fall back to an identity self-attention row.
    """
    B, L = v.shape
    dtype = torch.float32
    # A (query, key) pair is allowed only when both tokens are visible.
    pair_ok = v.unsqueeze(2) & v.unsqueeze(1)
    bias = torch.zeros(B, L, L, device=v.device, dtype=dtype)
    bias.masked_fill_(~pair_ok, -1e4)
    # Rows for non-visible queries collapse to identity self-attention so the
    # softmax stays well defined; their outputs are zeroed downstream anyway.
    not_q = ~v
    eye = torch.eye(L, device=v.device, dtype=dtype).unsqueeze(0)
    off_diag = torch.full((1, L, L), -1e4, device=v.device, dtype=dtype)
    identity_row = torch.where(eye > 0.5, torch.zeros_like(off_diag), off_diag)
    identity_row = identity_row.expand(B, -1, -1)
    bias = torch.where(not_q.unsqueeze(-1), identity_row, bias)
    return bias
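
# A minimal usage sketch of the bias (assumption for this example only: `v` is a
# boolean (B, L) visibility mask; the toy input and the values shown are
# illustrative, not part of the original module):
#
#     v = torch.tensor([[True, True, False, False]])
#     bias = build_encoder_attn_bias(v)
#     # bias[0] is then
#     #   [[     0.,      0., -10000., -10000.],
#     #    [     0.,      0., -10000., -10000.],
#     #    [-10000., -10000.,      0., -10000.],
#     #    [-10000., -10000., -10000.,      0.]]
#
# Rows 0-1 (visible queries) attend only to columns 0-1 (visible keys);
# rows 2-3 (non-visible queries) keep only their diagonal entry.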


class EncoderAttentionBlock(nn.Module):
    """Pre-norm transformer block whose self-attention respects a visibility mask."""

    def __init__(self, dim: int, n_heads: int, mlp_ratio: float, dropout: float):
        super().__init__()
        if dim % n_heads != 0:
            raise ValueError("dim must be divisible by n_heads")
        self.n_heads = n_heads
        self.head_dim = dim // n_heads
        self.scale = self.head_dim**-0.5
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.qkv = nn.Linear(dim, 3 * dim)
        self.proj = nn.Linear(dim, dim)
        self.attn_drop = nn.Dropout(dropout)
        self.mlp = MLP(dim, int(dim * mlp_ratio), dropout)

    def forward(self, x: torch.Tensor, v: torch.Tensor, attn_bias: torch.Tensor) -> torch.Tensor:
        # Attention sub-layer with a pre-norm residual connection.
        x_norm = self.norm1(x)
        out = self._visibility_attn(x_norm, v, attn_bias)
        x = x + out
        # MLP sub-layer with a pre-norm residual connection.
        x = x + self.mlp(self.norm2(x))
        # Zero out non-visible positions so they carry no signal to later blocks.
        x = x * v.unsqueeze(-1).to(x.dtype)
        return x

    def _visibility_attn(
        self, x: torch.Tensor, v: torch.Tensor, attn_bias: torch.Tensor
    ) -> torch.Tensor:
        B, L, D = x.shape
        H, Dh = self.n_heads, self.head_dim
        # Project to Q, K, V and split heads: q, k, val are each (B, H, L, Dh).
        qkv = self.qkv(x).reshape(B, L, 3, H, Dh).permute(2, 0, 3, 1, 4)
        q, k, val = qkv[0], qkv[1], qkv[2]
        scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale
        # Broadcast the (B, L, L) additive bias across heads.
        scores = scores + attn_bias.unsqueeze(1)
        attn = F.softmax(scores, dim=-1)
        # Zero attention rows belonging to non-visible queries, then renormalize
        # the surviving (visible-query) rows.
        attn = attn * v.unsqueeze(1).unsqueeze(-1)
        attn_sum = attn.sum(dim=-1, keepdim=True)
        attn = torch.where(attn_sum > 0, attn / attn_sum.clamp_min(1e-8), attn)
        attn = self.attn_drop(attn)
        out = torch.matmul(attn, val).transpose(1, 2).reshape(B, L, D)
        return self.proj(out)
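

# Minimal smoke-test sketch (not part of the original module). It only uses the
# two definitions above; the shapes and hyperparameters below are illustrative.
if __name__ == "__main__":
    torch.manual_seed(0)
    B, L, D = 2, 8, 32
    v = torch.rand(B, L) > 0.5  # boolean (B, L) visibility mask
    x = torch.randn(B, L, D)
    bias = build_encoder_attn_bias(v)
    block = EncoderAttentionBlock(dim=D, n_heads=4, mlp_ratio=4.0, dropout=0.0)
    y = block(x, v, bias)
    print(y.shape)  # expected: torch.Size([2, 8, 32])
    # Non-visible positions are zeroed by the block's output mask.
    assert torch.all(y[~v] == 0)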