File size: 2,713 Bytes
7a63dcf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
"""Visibility-restricted encoder attention (CoRe-ECG reconstruction encoder)."""

from __future__ import annotations

import math

import torch
import torch.nn as nn
import torch.nn.functional as F

from mae.mlp import MLP


def build_encoder_attn_bias(v: torch.Tensor) -> torch.Tensor:
    """
    Build an additive attention bias of shape (B, L, L) from a visibility mask.

    Query ``i`` of sample ``b`` may attend to key ``j`` (bias 0.0) when both
    ``v[b, i]`` and ``v[b, j]`` are True. A query whose own entry is False
    gets an identity row instead: it may attend only to itself. Every
    disallowed (query, key) pair receives -1e4.

    Args:
        v: (B, L) boolean visibility mask.

    Returns:
        (B, L, L) float32 tensor on ``v``'s device holding 0.0 / -1e4.
    """
    batch, length = v.shape
    # Outer AND of the mask with itself: pairs where query and key are visible.
    visible_pairs = v[:, :, None] & v[:, None, :]
    # Hidden queries are allowed to look at themselves only (the diagonal).
    diagonal = torch.eye(length, dtype=torch.bool, device=v.device)[None, :, :]
    hidden_self = (~v)[:, :, None] & diagonal
    allowed = visible_pairs | hidden_self
    blocked = torch.full(
        (batch, length, length), -1e4, dtype=torch.float32, device=v.device
    )
    return blocked.masked_fill(allowed, 0.0)


class EncoderAttentionBlock(nn.Module):
    """
    Pre-norm transformer block whose self-attention is restricted by a
    visibility mask.

    A forward pass computes ``x + attn(norm1(x))``, then ``x + mlp(norm2(x))``,
    and finally zeroes every position whose ``v`` entry is False.

    Args:
        dim: model width; must be divisible by ``n_heads``.
        n_heads: number of attention heads.
        mlp_ratio: MLP hidden width as a multiple of ``dim``.
        dropout: dropout probability for the attention weights; also passed
            to the MLP constructor.

    Raises:
        ValueError: if ``dim`` is not divisible by ``n_heads``.
    """

    def __init__(self, dim: int, n_heads: int, mlp_ratio: float, dropout: float):
        super().__init__()
        if dim % n_heads != 0:
            raise ValueError("d_model must be divisible by n_heads")
        self.n_heads = n_heads
        self.head_dim = dim // n_heads
        self.scale = self.head_dim**-0.5  # 1/sqrt(head_dim) dot-product scaling
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.qkv = nn.Linear(dim, 3 * dim)  # fused Q, K, V projection
        self.proj = nn.Linear(dim, dim)
        self.attn_drop = nn.Dropout(dropout)
        # NOTE(review): MLP comes from mae.mlp (not shown); argument meaning assumed
        # to be (in_dim, hidden_dim, dropout) — confirm against that module.
        self.mlp = MLP(dim, int(dim * mlp_ratio), dropout)

    def forward(self, x: torch.Tensor, v: torch.Tensor, attn_bias: torch.Tensor) -> torch.Tensor:
        """
        Run one attention + MLP block and zero non-visible positions.

        Args:
            x: (B, L, D) token features.
            v: (B, L) visibility mask; False marks positions to be zeroed.
            attn_bias: (B, L, L) additive bias added to the attention scores
                (for example the output of ``build_encoder_attn_bias``).

        Returns:
            (B, L, D) tensor in which every position with ``v == False`` is 0.
        """
        x = x + self._visibility_attn(self.norm1(x), v, attn_bias)
        x = x + self.mlp(self.norm2(x))
        # Whatever the residual path wrote into hidden positions is discarded.
        return x * v.unsqueeze(-1).to(x.dtype)

    def _visibility_attn(
        self, x: torch.Tensor, v: torch.Tensor, attn_bias: torch.Tensor
    ) -> torch.Tensor:
        """
        Multi-head attention with an additive bias and query-row masking.

        Attention rows of non-visible queries are zeroed. Those positions
        therefore receive only the output projection's bias.

        Args:
            x: (B, L, D) normalized features.
            v: (B, L) visibility mask.
            attn_bias: (B, L, L) additive score bias, broadcast over heads.

        Returns:
            (B, L, D) output of ``self.proj``, in the dtype of ``x``'s projections.
        """
        B, L, D = x.shape
        H, Dh = self.n_heads, self.head_dim
        # (B, L, 3*D) -> (3, B, H, L, Dh)
        qkv = self.qkv(x).reshape(B, L, 3, H, Dh).permute(2, 0, 3, 1, 4)
        q, k, val = qkv[0], qkv[1], qkv[2]
        scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale
        # The bias may have a different dtype than the activations (e.g. float32
        # bias with bfloat16 activations); type promotion then widens ``scores``.
        scores = scores + attn_bias.unsqueeze(1)
        attn = F.softmax(scores, dim=-1)
        # v -> (B, 1, L, 1): zeroes entire rows belonging to non-visible queries.
        attn = attn * v.unsqueeze(1).unsqueeze(-1)
        # Renormalize rows that still carry weight; all-zero rows stay zero.
        attn_sum = attn.sum(dim=-1, keepdim=True)
        attn = torch.where(attn_sum > 0, attn / attn_sum.clamp_min(1e-8), attn)
        attn = self.attn_drop(attn)
        # Bring the (possibly promoted) weights back to the value dtype; otherwise
        # matmul raises a dtype mismatch for half/bfloat16 activations. This is a
        # no-op when everything is already float32.
        attn = attn.to(val.dtype)
        out = torch.matmul(attn, val).transpose(1, 2).reshape(B, L, D)
        return self.proj(out)