| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
| from functools import partial |
| from typing import Literal |
|
|
| class DepthConv1d(nn.Module): |
| """Depthwise conv1d. |
| |
| Args: |
| causal: use a causal convolution mask. |
| positive: constrain kernel to be non-negative. |
| blockwise: single kernel shared across all channels. |
| |
| Shape: |
| input: (*, L, C) |
| output: (*, L, C) |
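
    Example (a minimal smoke test; the sizes are illustrative assumptions):
        >>> conv = DepthConv1d(8, kernel_size=5, causal=True)
        >>> conv(torch.randn(2, 16, 8)).shape
        torch.Size([2, 16, 8])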
| """ |
|
|
| attn_mask: torch.Tensor |
|
|
| def __init__( |
| self, |
| embed_dim: int, |
| kernel_size: int, |
| causal: bool = False, |
| positive: bool = False, |
| blockwise: bool = False, |
| bias: bool = True, |
| ): |
        assert not causal or kernel_size % 2 == 1, "causal conv requires odd kernel"

        super().__init__()
        self.embed_dim = embed_dim
        self.kernel_size = kernel_size
        self.causal = causal
        self.positive = positive
        self.blockwise = blockwise

        if blockwise:
            # One kernel shared across all channels.
            weight_shape = (1, 1, kernel_size)
        else:
            weight_shape = (embed_dim, 1, kernel_size)
        self.weight = nn.Parameter(torch.empty(weight_shape))

        if bias:
            self.bias = nn.Parameter(torch.empty(embed_dim))
        else:
            self.register_parameter("bias", None)

        # Mask that zeroes the kernel taps right of center; with padding="same"
        # those taps would read future time steps.
        attn_mask = torch.ones(kernel_size)
        if self.causal:
            attn_mask[kernel_size // 2 + 1 :] = 0.0
        self.register_buffer("attn_mask", attn_mask)

        self.init_weights()
|
|
| def init_weights(self): |
| nn.init.trunc_normal_(self.weight, std=0.02) |
| if self.bias is not None: |
| nn.init.zeros_(self.bias) |
|
|
    def forward(self, input: torch.Tensor) -> torch.Tensor:
        *leading_dims, L, C = input.shape
        assert C == self.embed_dim

        # Flatten leading dims and move channels first: (N, C, L).
        input = input.reshape(-1, L, C).transpose(1, 2)

        weight = self.weight * self.attn_mask
        if self.positive:
            weight = weight.abs()
        if self.blockwise:
            weight = weight.expand((self.embed_dim, 1, self.kernel_size))

        output = F.conv1d(
            input, weight, self.bias, padding="same", groups=self.embed_dim
        )

        # Restore (*, L, C).
        output = output.transpose(1, 2)
        output = output.reshape(leading_dims + [L, C])
        return output
|
|
| def extra_repr(self): |
| return ( |
| f"{self.embed_dim}, kernel_size={self.kernel_size}, " |
| f"causal={self.causal}, positive={self.positive}, " |
| f"blockwise={self.blockwise}, bias={self.bias is not None}" |
| ) |
|
|
|
|
| class LinearPoolLatent(nn.Module): |
| """Learned linear pooling over a set of features. |
| |
| Shape: |
| input: (*, L, C) |
| output: (*, C) |
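
    Example (a minimal sketch; feat_size is the length of the pooled axis):
        >>> pool = LinearPoolLatent(embed_dim=8, feat_size=3)
        >>> pool(torch.randn(2, 3, 8)).shape
        torch.Size([2, 8])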
| """ |
|
|
| def __init__(self, embed_dim: int, feat_size: int, positive: bool = False): |
| super().__init__() |
| self.embed_dim = embed_dim |
| self.feat_size = feat_size |
| self.positive = positive |
|
|
| self.weight = nn.Parameter(torch.empty(feat_size, embed_dim)) |
| self.init_weights() |
|
|
| def init_weights(self): |
| nn.init.trunc_normal_(self.weight, std=0.02) |
|
|
    def forward(self, input: torch.Tensor) -> torch.Tensor:
        weight = self.weight
        if self.positive:
            weight = weight.abs()
        # Weighted sum over the pooled (dim=-2) axis.
        output = torch.sum(input * weight, dim=-2)
        return output
|
|
| def extra_repr(self): |
| return f"{self.embed_dim}, feat_size={self.feat_size}, positive={self.positive}" |
|
|
|
|
| class AttentionPoolLatent(nn.Module): |
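    """Attention pooling with a single learned latent query.

    Shape:
        input: (*, L, C)
        output: (*, C)

    Example (a minimal sketch; the sizes are illustrative assumptions):
        >>> pool = AttentionPoolLatent(embed_dim=16, num_heads=4)
        >>> pool(torch.randn(2, 10, 16)).shape
        torch.Size([2, 16])
    """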
| def __init__( |
| self, |
| embed_dim: int, |
| num_heads: int = 8, |
| ): |
| super().__init__() |
| assert embed_dim % num_heads == 0 |
| self.embed_dim = embed_dim |
| self.num_heads = num_heads |
| self.query = nn.Parameter(torch.zeros(embed_dim)) |
| self.kv = nn.Linear(embed_dim, 2 * embed_dim) |
| self.proj = nn.Linear(embed_dim, embed_dim) |
| self.init_weights() |
|
|
| def init_weights(self): |
| nn.init.trunc_normal_(self.query, std=0.02) |
|
|
    def forward(self, x: torch.Tensor):
        *leading_dims, L, C = x.shape
        x = x.reshape(-1, L, C)
        N = len(x)
        h = self.num_heads

        # Broadcast the learned query to the batch: (N, h, 1, C // h).
        q = self.query.expand(N, 1, C)
        q = q.reshape(N, 1, h, C // h).transpose(1, 2)

        # Project keys and values: each (N, h, L, C // h).
        kv = self.kv(x)
        kv = kv.reshape(N, L, 2, h, C // h).permute(2, 0, 3, 1, 4)
        k, v = torch.unbind(kv, dim=0)

        x = F.scaled_dot_product_attention(q, k, v)
        x = x.reshape(N, C)
        x = self.proj(x)

        x = x.reshape(leading_dims + [C])
        return x
|
|
| class ConvLinear(nn.Module): |
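    """Depthwise temporal conv followed by a pointwise linear projection.

    Shape:
        input: (*, L, in_features)
        output: (*, L, out_features)
    """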
| def __init__( |
| self, |
| in_features: int, |
| out_features: int, |
| kernel_size: int = 11, |
| causal: bool = False, |
| ): |
| super().__init__() |
        # DepthConv1d is already depthwise (groups == channels internally),
        # so it takes no groups argument.
        self.conv = DepthConv1d(
            in_features,
            kernel_size=kernel_size,
            causal=causal,
        )
| self.fc = nn.Linear(in_features, out_features) |
|
|
    def forward(self, x: torch.Tensor):
        # Depthwise temporal mixing, then pointwise projection.
        x = self.conv(x)
        x = self.fc(x)
        return x
|
|
|
|
| class LinearConv(nn.Module): |
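    """Pointwise linear projection followed by a depthwise temporal conv.

    Shape:
        input: (*, L, in_features)
        output: (*, L, out_features)

    Example (a minimal sketch; the sizes are illustrative assumptions):
        >>> layer = LinearConv(32, 16, kernel_size=5, causal=True)
        >>> layer(torch.randn(2, 10, 32)).shape
        torch.Size([2, 10, 16])
    """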
| def __init__( |
| self, |
| in_features: int, |
| out_features: int, |
| kernel_size: int = 11, |
| causal: bool = False, |
| positive: bool = False, |
| blockwise: bool = False, |
| ): |
| super().__init__() |
| self.fc = nn.Linear(in_features, out_features) |
| self.conv = DepthConv1d( |
| out_features, |
| kernel_size=kernel_size, |
| causal=causal, |
| positive=positive, |
| blockwise=blockwise, |
| ) |
|
|
    def forward(self, x: torch.Tensor):
        # Pointwise projection, then depthwise temporal mixing.
        x = self.fc(x)
        x = self.conv(x)
        return x


class MultiSubjectConvLinearEncoder(nn.Module):
    """Encodes one or more feature streams into per-subject target predictions.

    Each stream is embedded to embed_dim, streams are pooled into a single
    sequence of embeddings, and shared and/or per-subject linear decoders map
    embeddings to targets.

    Shape:
        input: list of (N, L, C_i) or (N, L, G_i, C_i) tensors
        output: (N, num_subjects, L, target_dim)
    """
| def __init__( |
| self, |
| num_subjects: int = 4, |
| feat_dims: tuple[int | tuple[int, int], ...] = (2048,), |
| embed_dim: int = 256, |
| target_dim: int = 1000, |
| hidden_model: nn.Module | None = None, |
| global_pool: Literal["avg", "linear", "attn"] = "avg", |
| encoder_kernel_size: int = 33, |
| decoder_kernel_size: int = 0, |
| encoder_causal: bool = True, |
| encoder_positive: bool = False, |
| encoder_blockwise: bool = False, |
| pool_num_heads: int = 4, |
| with_shared_decoder: bool = True, |
| with_subject_decoders: bool = True, |
| ): |
| assert with_shared_decoder or with_subject_decoders |
|
|
| super().__init__() |
| self.num_subjects = num_subjects |
| self.global_pool = global_pool |
|
|
| |
        # Normalize feat_dims entries to (num_groups, channels) pairs.
        feat_dims = [(1, dim) if isinstance(dim, int) else dim for dim in feat_dims]
        total_feat_size = sum(dim[0] for dim in feat_dims)
|
|
| self.feat_embeds = nn.ModuleList( |
| [ |
| _make_feat_embed( |
| dim[1], |
| embed_dim, |
| kernel_size=encoder_kernel_size, |
| causal=encoder_causal, |
| positive=encoder_positive, |
| blockwise=encoder_blockwise, |
| ) |
| for dim in feat_dims |
| ] |
| ) |
|
|
        if global_pool == "avg":
            self.register_module("feat_pool", None)
        elif global_pool == "linear":
            self.feat_pool = LinearPoolLatent(embed_dim, total_feat_size)
        elif global_pool == "attn":
            # AttentionPoolLatent pools with a learned query over the feature
            # axis, so it only needs embed_dim and num_heads.
            self.feat_pool = AttentionPoolLatent(embed_dim, num_heads=pool_num_heads)
        else:
            raise NotImplementedError(f"global_pool {global_pool} not implemented.")
|
|
| if hidden_model is not None: |
| self.hidden_model = hidden_model |
| else: |
| self.register_module("hidden_model", None) |
|
|
| if with_shared_decoder: |
| self.shared_decoder = nn.Linear(embed_dim, target_dim) |
| else: |
| self.register_module("shared_decoder", None) |
|
|
| if decoder_kernel_size > 1: |
| decoder_linear = partial(ConvLinear, kernel_size=decoder_kernel_size) |
| else: |
| decoder_linear = nn.Linear |
|
|
| if with_subject_decoders: |
| self.subject_decoders = nn.ModuleList( |
| [decoder_linear(embed_dim, target_dim) for _ in range(num_subjects)] |
| ) |
| else: |
| self.register_module("subject_decoders", None) |
|
|
| self.apply(_init_weights) |
|
|
    def forward(self, features: list[torch.Tensor]):
        # Embed each stream to (N, G_i, L, embed_dim).
        embed_features: list[torch.Tensor] = []
        for feat, feat_embed in zip(features, self.feat_embeds):
            # (N, L, C) -> (N, 1, L, C); (N, L, G, C) -> (N, G, L, C).
            feat = feat[:, None] if feat.ndim == 3 else feat.transpose(1, 2)

            feat = feat_embed(feat)
            embed_features.append(feat)

        if self.global_pool == "avg":
            embed = sum(feat.mean(dim=1) for feat in embed_features)
        else:
            # Pool the concatenated group axis: (N, L, sum(G_i), D) -> (N, L, D).
            embed = torch.cat(embed_features, dim=1)
            embed = embed.transpose(1, 2)
            embed = self.feat_pool(embed)

        if self.hidden_model is not None:
            embed = self.hidden_model(embed)

        if self.shared_decoder is not None:
            shared_output = self.shared_decoder(embed)
            shared_output = shared_output[:, None].expand(-1, self.num_subjects, -1, -1)
        else:
            shared_output = 0.0

        if self.subject_decoders is not None:
            subject_output = torch.stack(
                [decoder(embed) for decoder in self.subject_decoders],
                dim=1,
            )
        else:
            subject_output = 0.0

        # (N, num_subjects, L, target_dim)
        output = subject_output + shared_output
        return output
|
|
|
|
| def _make_feat_embed( |
| feat_dim: int = 2048, |
| embed_dim: int = 256, |
| kernel_size: int = 33, |
| causal: bool = True, |
| positive: bool = False, |
| blockwise: bool = False, |
| ) -> nn.Module: |
    """Build a per-stream embedding: a linear projection, followed by a depthwise conv when kernel_size > 1."""
    if kernel_size > 1:
| embed = LinearConv( |
| feat_dim, |
| embed_dim, |
| kernel_size=kernel_size, |
| causal=causal, |
| positive=positive, |
| blockwise=blockwise, |
| ) |
| else: |
| embed = nn.Linear(feat_dim, embed_dim) |
| return embed |
|
|
|
|
def _init_weights(m: nn.Module) -> None:
    if isinstance(m, (nn.Conv1d, nn.Linear, DepthConv1d)):
        nn.init.trunc_normal_(m.weight, std=0.02)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
    elif isinstance(m, (nn.LayerNorm, nn.RMSNorm)) and m.elementwise_affine:
        nn.init.constant_(m.weight, 1.0)
        # nn.RMSNorm has no bias attribute, hence the hasattr guard.
        if hasattr(m, "bias") and m.bias is not None:
            nn.init.constant_(m.bias, 0)
|
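
if __name__ == "__main__":
    # A minimal smoke test (illustrative usage, not part of the original
    # module): one 2048-dim feature stream, four subjects, default settings.
    model = MultiSubjectConvLinearEncoder(
        num_subjects=4,
        feat_dims=(2048,),
        embed_dim=256,
        target_dim=1000,
    )
    feats = [torch.randn(2, 50, 2048)]
    output = model(feats)
    assert output.shape == (2, 4, 50, 1000), output.shape
    print("output:", tuple(output.shape))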
|