| """DiCo block: conv path (1x1 -> depthwise -> SiLU -> CCA -> 1x1) + GELU MLP.""" |
|
|
| from __future__ import annotations |
|
|
| import torch |
| import torch.nn.functional as F |
| from torch import Tensor, nn |
|
|
| from .compact_channel_attention import CompactChannelAttention |
| from .conv_mlp import ConvMLP |
| from .norms import ChannelWiseRMSNorm |
|
|
|
|
class DiCoBlock(nn.Module):
    """DiCo-style conv block with optional external AdaLN conditioning.

    Each block applies two residual sub-blocks:

    1. Conv path: RMSNorm -> 1x1 conv -> depthwise conv -> SiLU ->
       compact channel attention (CCA) gating -> 1x1 conv.
    2. MLP path: RMSNorm -> ConvMLP.

    Two modes:
    - Unconditioned (encoder): uses learned per-channel residual gates,
      zero-initialized so every sub-block starts as an identity mapping.
    - External AdaLN (decoder): receives packed modulation [B, 4*C] via
      ``adaln_m``, split along the last dim into
      (scale_attn, gate_attn, scale_mlp, gate_mlp); gates pass through
      ``tanh`` so the residual contribution stays bounded.
    """

    def __init__(
        self,
        channels: int,
        mlp_ratio: float,
        *,
        depthwise_kernel_size: int = 7,
        use_external_adaln: bool = False,
        norm_eps: float = 1e-6,
    ) -> None:
        """Build the block.

        Args:
            channels: Number of input/output channels C.
            mlp_ratio: MLP hidden width as a multiple of ``channels``
                (rounded, minimum 1).
            depthwise_kernel_size: Spatial kernel of the depthwise conv.
                Must be odd so ``k // 2`` padding preserves spatial size.
            use_external_adaln: Enable the externally-conditioned mode.
            norm_eps: Epsilon for the RMS norms (also forwarded to ConvMLP).

        Raises:
            ValueError: If ``depthwise_kernel_size`` is even (padding would
                silently change the spatial resolution).
        """
        super().__init__()
        if depthwise_kernel_size % 2 == 0:
            raise ValueError(
                f"depthwise_kernel_size must be odd, got {depthwise_kernel_size}"
            )
        self.channels = int(channels)
        self.use_external_adaln = bool(use_external_adaln)

        # Parameter-free norms: per-channel scaling is supplied either by the
        # AdaLN modulation or absorbed into the following 1x1 convs.
        self.norm1 = ChannelWiseRMSNorm(self.channels, eps=norm_eps, affine=False)
        self.norm2 = ChannelWiseRMSNorm(self.channels, eps=norm_eps, affine=False)

        # Conv path: pointwise -> depthwise (spatial mixing) -> pointwise.
        self.conv1 = nn.Conv2d(self.channels, self.channels, kernel_size=1, bias=True)
        self.conv2 = nn.Conv2d(
            self.channels,
            self.channels,
            kernel_size=depthwise_kernel_size,
            padding=depthwise_kernel_size // 2,  # "same" padding (odd kernel)
            groups=self.channels,  # depthwise: one filter per channel
            bias=True,
        )
        self.conv3 = nn.Conv2d(self.channels, self.channels, kernel_size=1, bias=True)
        self.cca = CompactChannelAttention(self.channels)

        hidden_channels = max(int(round(float(self.channels) * mlp_ratio)), 1)
        self.mlp = ConvMLP(self.channels, hidden_channels, norm_eps=norm_eps)

        if not self.use_external_adaln:
            # Zero-init gates: both residual branches start as identities.
            self.gate_attn = nn.Parameter(torch.zeros(self.channels))
            self.gate_mlp = nn.Parameter(torch.zeros(self.channels))

    def forward(self, x: Tensor, *, adaln_m: Tensor | None = None) -> Tensor:
        """Apply the block.

        Args:
            x: Input feature map of shape [B, C, H, W] with C == ``channels``.
            adaln_m: Packed modulation of shape [B, 4*C]; required iff the
                block was built with ``use_external_adaln=True``.

        Returns:
            Tensor of the same shape as ``x``.

        Raises:
            ValueError: If the channel count mismatches, or ``adaln_m`` is
                missing/present inconsistently with the block's mode, or has
                the wrong trailing dimension.
        """
        b, c = x.shape[:2]
        if c != self.channels:
            raise ValueError(
                f"expected input with {self.channels} channels, got {c}"
            )

        if self.use_external_adaln:
            if adaln_m is None:
                raise ValueError(
                    "adaln_m required for externally-conditioned DiCoBlock"
                )
            # Validate here so a wrongly-packed tensor fails with a clear
            # message instead of a confusing chunk/view error downstream.
            if adaln_m.shape[-1] != 4 * self.channels:
                raise ValueError(
                    f"adaln_m last dim must be {4 * self.channels}, "
                    f"got {adaln_m.shape[-1]}"
                )
            adaln_m_cast = adaln_m.to(device=x.device, dtype=x.dtype)
            scale_a, gate_a, scale_m, gate_m = adaln_m_cast.chunk(4, dim=-1)
        elif adaln_m is not None:
            raise ValueError("adaln_m must be None for unconditioned DiCoBlock")

        residual = x

        # Conv path (pre-norm, optionally AdaLN-scaled).
        x_att = self.norm1(x)
        if self.use_external_adaln:
            x_att = x_att * (1.0 + scale_a.view(b, c, 1, 1))
        y = self.conv1(x_att)
        y = self.conv2(y)
        y = F.silu(y)
        y = y * self.cca(y)  # channel-wise attention gating
        y = self.conv3(y)

        if self.use_external_adaln:
            gate_a_view = torch.tanh(gate_a).view(b, c, 1, 1)
            x = residual + gate_a_view * y
        else:
            gate = self.gate_attn.view(1, self.channels, 1, 1).to(
                dtype=y.dtype, device=y.device
            )
            x = residual + gate * y

        # MLP path (pre-norm, optionally AdaLN-scaled).
        residual_mlp = x
        x_mlp = self.norm2(x)
        if self.use_external_adaln:
            x_mlp = x_mlp * (1.0 + scale_m.view(b, c, 1, 1))
        y_mlp = self.mlp(x_mlp)

        if self.use_external_adaln:
            gate_m_view = torch.tanh(gate_m).view(b, c, 1, 1)
            x = residual_mlp + gate_m_view * y_mlp
        else:
            gate = self.gate_mlp.view(1, self.channels, 1, 1).to(
                dtype=y_mlp.dtype, device=y_mlp.device
            )
            x = residual_mlp + gate * y_mlp

        return x
|
|