| """Decoder transformer blocks (full-sequence self-attention).""" | |
| from __future__ import annotations | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from mae.mlp import MLP | |
class DecoderBlock(nn.Module):
    """Pre-norm transformer block used in the MAE decoder."""

    def __init__(self, dim: int, n_heads: int, mlp_ratio: float, dropout: float):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(
            dim, n_heads, dropout=dropout, batch_first=True
        )
        self.mlp = MLP(dim, int(dim * mlp_ratio), dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Self-attention over the full token sequence, with a residual
        # connection around the pre-normed input.
        x2 = self.norm1(x)
        a, _ = self.attn(x2, x2, x2, need_weights=False)
        x = x + a
        # Position-wise MLP sub-block, also pre-normed and residual.
        x = x + self.mlp(self.norm2(x))
        return x
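

# Minimal usage sketch (not part of the original module): constructs a block
# and runs a random token sequence through it to show that the output shape
# matches the input. The batch size, sequence length, width, head count, and
# MLP ratio below are illustrative assumptions, not values from this repo.
if __name__ == "__main__":
    torch.manual_seed(0)
    block = DecoderBlock(dim=512, n_heads=16, mlp_ratio=4.0, dropout=0.0)
    tokens = torch.randn(2, 197, 512)  # (batch, sequence length, embedding dim)
    out = block(tokens)
    assert out.shape == tokens.shape  # the block preserves the sequence shape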