import torch import torch.nn as nn import torch.nn.functional as F from functools import partial from typing import Literal class DepthConv1d(nn.Module): """Depthwise conv1d. Args: causal: use a causal convolution mask. positive: constrain kernel to be non-negative. blockwise: single kernel shared across all channels. Shape: input: (*, L, C) output: (*, L, C) """ attn_mask: torch.Tensor def __init__( self, embed_dim: int, kernel_size: int, causal: bool = False, positive: bool = False, blockwise: bool = False, bias: bool = True, ): assert not causal or kernel_size % 2 == 1, "causal conv requires odd kernel" super().__init__() self.embed_dim = embed_dim self.kernel_size = kernel_size self.causal = causal self.positive = positive self.blockwise = blockwise if blockwise: weight_shape = (1, 1, kernel_size) else: weight_shape = (embed_dim, 1, kernel_size) self.weight = nn.Parameter(torch.empty(weight_shape)) if bias: self.bias = nn.Parameter(torch.empty(embed_dim)) else: self.register_parameter("bias", None) attn_mask = torch.ones(kernel_size) if self.causal: attn_mask[kernel_size // 2 + 1 :] = 0.0 self.register_buffer("attn_mask", attn_mask) self.init_weights() def init_weights(self): nn.init.trunc_normal_(self.weight, std=0.02) if self.bias is not None: nn.init.zeros_(self.bias) def forward(self, input: torch.Tensor) -> torch.Tensor: # input: (*, L, C) # output: (*, L, C) *leading_dims, L, C = input.shape assert C == self.embed_dim # (*, L, C) -> (N, C, L) input = input.reshape(-1, L, C).transpose(1, 2) weight = self.weight * self.attn_mask if self.positive: weight = weight.abs() if self.blockwise: weight = weight.expand((self.embed_dim, 1, self.kernel_size)) output = F.conv1d( input, weight, self.bias, padding="same", groups=self.embed_dim ) output = output.transpose(1, 2) output = output.reshape(leading_dims + [L, C]) return output def extra_repr(self): return ( f"{self.embed_dim}, kernel_size={self.kernel_size}, " f"causal={self.causal}, positive={self.positive}, " f"blockwise={self.blockwise}, bias={self.bias is not None}" ) class LinearPoolLatent(nn.Module): """Learned linear pooling over a set of features. Shape: input: (*, L, C) output: (*, C) """ def __init__(self, embed_dim: int, feat_size: int, positive: bool = False): super().__init__() self.embed_dim = embed_dim self.feat_size = feat_size self.positive = positive self.weight = nn.Parameter(torch.empty(feat_size, embed_dim)) self.init_weights() def init_weights(self): nn.init.trunc_normal_(self.weight, std=0.02) def forward(self, input: torch.Tensor) -> torch.Tensor: # input: (*, L, C) # output: (*, C) weight = self.weight if self.positive: weight = weight.abs() input = torch.sum(input * weight, dim=-2) return input def extra_repr(self): return f"{self.embed_dim}, feat_size={self.feat_size}, positive={self.positive}" class AttentionPoolLatent(nn.Module): def __init__( self, embed_dim: int, num_heads: int = 8, ): super().__init__() assert embed_dim % num_heads == 0 self.embed_dim = embed_dim self.num_heads = num_heads self.query = nn.Parameter(torch.zeros(embed_dim)) self.kv = nn.Linear(embed_dim, 2 * embed_dim) self.proj = nn.Linear(embed_dim, embed_dim) self.init_weights() def init_weights(self): nn.init.trunc_normal_(self.query, std=0.02) def forward(self, x: torch.Tensor): *leading_dims, L, C = x.shape x = x.reshape(-1, L, C) N = len(x) h = self.num_heads # fixed learned query q = self.query.expand(N, 1, C) q = q.reshape(N, 1, h, C // h).transpose(1, 2) # [N, h, 1, C] # keys, values for each input feature map kv = self.kv(x) kv = kv.reshape(N, L, 2, h, C // h).permute(2, 0, 3, 1, 4) # [2, N, h, L, C] k, v = torch.unbind(kv, dim=0) x = F.scaled_dot_product_attention(q, k, v) # [N, h, 1, C] x = x.reshape(N, C) x = self.proj(x) x = x.reshape(leading_dims + [C]) return x class ConvLinear(nn.Module): def __init__( self, in_features: int, out_features: int, kernel_size: int = 11, causal: bool = False, ): super().__init__() self.conv = DepthConv1d( in_features, kernel_size=kernel_size, causal=causal, groups=in_features, ) self.fc = nn.Linear(in_features, out_features) def forward(self, x: torch.Tensor): # x: (N, L, C) x = self.conv(x) x = self.fc(x) return x class LinearConv(nn.Module): def __init__( self, in_features: int, out_features: int, kernel_size: int = 11, causal: bool = False, positive: bool = False, blockwise: bool = False, ): super().__init__() self.fc = nn.Linear(in_features, out_features) self.conv = DepthConv1d( out_features, kernel_size=kernel_size, causal=causal, positive=positive, blockwise=blockwise, ) def forward(self, x: torch.Tensor): # x: (N, L, C) x = self.fc(x) x = self.conv(x) return x class MultiSubjectConvLinearEncoder(nn.Module): weight: torch.Tensor def __init__( self, num_subjects: int = 4, feat_dims: tuple[int | tuple[int, int], ...] = (2048,), embed_dim: int = 256, target_dim: int = 1000, hidden_model: nn.Module | None = None, global_pool: Literal["avg", "linear", "attn"] = "avg", encoder_kernel_size: int = 33, decoder_kernel_size: int = 0, encoder_causal: bool = True, encoder_positive: bool = False, encoder_blockwise: bool = False, pool_num_heads: int = 4, with_shared_decoder: bool = True, with_subject_decoders: bool = True, ): assert with_shared_decoder or with_subject_decoders super().__init__() self.num_subjects = num_subjects self.global_pool = global_pool # list of (nfeats, dim) feat_dims = [(1, dim) if isinstance(dim, int) else dim for dim in feat_dims] total_feat_size = sum(dim[0] for dim in feat_dims) self.feat_embeds = nn.ModuleList( [ _make_feat_embed( dim[1], embed_dim, kernel_size=encoder_kernel_size, causal=encoder_causal, positive=encoder_positive, blockwise=encoder_blockwise, ) for dim in feat_dims ] ) if global_pool == "avg": self.register_module("feat_pool", None) elif global_pool == "linear": self.feat_pool = LinearPoolLatent(embed_dim, total_feat_size) elif global_pool == "attn": self.feat_pool = AttentionPoolLatent( embed_dim, total_feat_size, num_heads=pool_num_heads ) else: raise NotImplementedError(f"global_pool {global_pool} not implemented.") if hidden_model is not None: self.hidden_model = hidden_model else: self.register_module("hidden_model", None) if with_shared_decoder: self.shared_decoder = nn.Linear(embed_dim, target_dim) else: self.register_module("shared_decoder", None) if decoder_kernel_size > 1: decoder_linear = partial(ConvLinear, kernel_size=decoder_kernel_size) else: decoder_linear = nn.Linear if with_subject_decoders: self.subject_decoders = nn.ModuleList( [decoder_linear(embed_dim, target_dim) for _ in range(num_subjects)] ) else: self.register_module("subject_decoders", None) self.apply(_init_weights) def forward(self, features: list[torch.Tensor]): # features: list of (N, T, D_i) or (N, T, L_i, D_i) embed_features: list[torch.Tensor] = [] for feat, feat_embed in zip(features, self.feat_embeds): # view as (N, L_i, T, D_i) feat = feat[:, None] if feat.ndim == 3 else feat.transpose(1, 2) # project to (N, L, T, d) feat = feat_embed(feat) embed_features.append(feat) if self.global_pool == "avg": embed = sum(feat.mean(dim=1) for feat in embed_features) else: embed = torch.cat(embed_features, dim=1) # (N, L, T, d) -> (N, T, L, d) embed = embed.transpose(1, 2) embed = self.feat_pool(embed) if self.hidden_model is not None: embed = self.hidden_model(embed) if self.shared_decoder is not None: shared_output = self.shared_decoder(embed) shared_output = shared_output[:, None].expand(-1, self.num_subjects, -1, -1) else: shared_output = 0.0 if self.subject_decoders is not None: subject_output = torch.stack( [decoder(embed) for decoder in self.subject_decoders], dim=1, ) else: subject_output = 0.0 output = subject_output + shared_output return output def _make_feat_embed( feat_dim: int = 2048, embed_dim: int = 256, kernel_size: int = 33, causal: bool = True, positive: bool = False, blockwise: bool = False, ) -> nn.Module: if kernel_size > 1: embed = LinearConv( feat_dim, embed_dim, kernel_size=kernel_size, causal=causal, positive=positive, blockwise=blockwise, ) else: embed = nn.Linear(feat_dim, embed_dim) return embed def _init_weights(m: nn.Module) -> None: if isinstance(m, (nn.Conv1d, nn.Linear, DepthConv1d)): nn.init.trunc_normal_(m.weight, std=0.02) if m.bias is not None: nn.init.constant_(m.bias, 0) elif isinstance(m, (nn.LayerNorm, nn.RMSNorm)) and m.elementwise_affine: nn.init.constant_(m.weight, 1.0) if hasattr(m, "bias") and m.bias is not None: nn.init.constant_(m.bias, 0)