"""Decoder transformer blocks (full-sequence self-attention).""" from __future__ import annotations import torch.nn as nn import torch.nn.functional as F from mae.mlp import MLP class DecoderBlock(nn.Module): def __init__(self, dim: int, n_heads: int, mlp_ratio: float, dropout: float): super().__init__() self.norm1 = nn.LayerNorm(dim) self.norm2 = nn.LayerNorm(dim) self.attn = nn.MultiheadAttention( dim, n_heads, dropout=dropout, batch_first=True ) self.mlp = MLP(dim, int(dim * mlp_ratio), dropout) def forward(self, x: torch.Tensor) -> torch.Tensor: x2 = self.norm1(x) a, _ = self.attn(x2, x2, x2, need_weights=False) x = x + a x = x + self.mlp(self.norm2(x)) return x