"""Visibility-restricted encoder attention (CoRe-ECG reconstruction encoder)."""

from __future__ import annotations

import torch
import torch.nn as nn
import torch.nn.functional as F

from mae.mlp import MLP


def build_encoder_attn_bias(v: torch.Tensor) -> torch.Tensor:
    """
    Additive attention bias (B, L, L): visible queries attend only to visible
    keys; non-visible queries get an identity self-attention row. Expects a
    boolean visibility mask ``v`` of shape (B, L).
    """
    B, L = v.shape
    dtype = torch.float32
    # pair_ok[b, i, j] is True iff both query i and key j are visible.
    pair_ok = v.unsqueeze(2) & v.unsqueeze(1)
    bias = torch.zeros(B, L, L, device=v.device, dtype=dtype)
    bias.masked_fill_(~pair_ok, -1e4)
    # Overwrite the rows of non-visible queries with an identity row:
    # 0 on the diagonal, -1e4 everywhere else.
    not_q = ~v
    eye = torch.eye(L, device=v.device, dtype=dtype).unsqueeze(0)
    off_diag = torch.full((1, L, L), -1e4, device=v.device, dtype=dtype)
    identity_row = torch.where(eye > 0.5, torch.zeros_like(off_diag), off_diag)
    identity_row = identity_row.expand(B, -1, -1)
    bias = torch.where(not_q.unsqueeze(-1), identity_row, bias)
    return bias
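

# Worked example (illustrative, toy shapes): with a single visible token,
# every masked query collapses onto the diagonal while the visible query
# keeps only visible keys:
#
#     >>> v = torch.tensor([[True, False, False]])
#     >>> build_encoder_attn_bias(v)
#     tensor([[[     0., -10000., -10000.],
#              [-10000.,      0., -10000.],
#              [-10000., -10000.,      0.]]])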


class EncoderAttentionBlock(nn.Module):
    """Pre-norm transformer block whose self-attention honors a visibility mask."""

    def __init__(self, dim: int, n_heads: int, mlp_ratio: float, dropout: float):
        super().__init__()
        if dim % n_heads != 0:
            raise ValueError("dim must be divisible by n_heads")
        self.n_heads = n_heads
        self.head_dim = dim // n_heads
        self.scale = self.head_dim**-0.5
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.qkv = nn.Linear(dim, 3 * dim)
        self.proj = nn.Linear(dim, dim)
        self.attn_drop = nn.Dropout(dropout)
        self.mlp = MLP(dim, int(dim * mlp_ratio), dropout)

    def forward(self, x: torch.Tensor, v: torch.Tensor, attn_bias: torch.Tensor) -> torch.Tensor:
        x_norm = self.norm1(x)
        out = self._visibility_attn(x_norm, v, attn_bias)
        x = x + out
        x = x + self.mlp(self.norm2(x))
        # Hard-zero non-visible positions so no masked state leaks downstream.
        x = x * v.unsqueeze(-1).to(x.dtype)
        return x

    def _visibility_attn(
        self, x: torch.Tensor, v: torch.Tensor, attn_bias: torch.Tensor
    ) -> torch.Tensor:
        B, L, D = x.shape
        H, Dh = self.n_heads, self.head_dim
        # (B, L, 3*D) -> three tensors of shape (B, H, L, Dh).
        qkv = self.qkv(x).reshape(B, L, 3, H, Dh).permute(2, 0, 3, 1, 4)
        q, k, val = qkv[0], qkv[1], qkv[2]
        scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale
        # Broadcast the (B, L, L) additive bias across heads.
        scores = scores + attn_bias.unsqueeze(1)
        attn = F.softmax(scores, dim=-1)
        # Zero out the attention rows of non-visible queries (mask broadcasts
        # along the key axis), then renormalize; all-zero rows are left at
        # zero rather than divided.
        attn = attn * v.unsqueeze(1).unsqueeze(-1)
        attn_sum = attn.sum(dim=-1, keepdim=True)
        attn = torch.where(attn_sum > 0, attn / attn_sum.clamp_min(1e-8), attn)
        attn = self.attn_drop(attn)
        out = torch.matmul(attn, val).transpose(1, 2).reshape(B, L, D)
        return self.proj(out)
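

if __name__ == "__main__":
    # Minimal smoke-test sketch. The shapes and hyperparameters below are
    # illustrative choices, not values prescribed by the module; it also
    # assumes mae.mlp.MLP is importable with the signature used above.
    torch.manual_seed(0)
    B, L, dim = 2, 8, 32
    v = torch.rand(B, L) > 0.5  # random boolean visibility mask, shape (B, L)
    x = torch.randn(B, L, dim)
    block = EncoderAttentionBlock(dim=dim, n_heads=4, mlp_ratio=4.0, dropout=0.0)
    bias = build_encoder_attn_bias(v)
    y = block(x, v, bias)
    # Non-visible positions must come out exactly zero.
    assert torch.all(y[~v] == 0)
    print(y.shape)  # torch.Size([2, 8, 32])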