| import torch
|
| import torch.nn as nn
|
| import torch.nn.functional as F
|
|
|
| from functools import partial
|
| from typing import Literal
|
|
|
class DepthConv1d(nn.Module):
    """Depthwise (per-channel) 1D convolution on channels-last input.

    The kernel slides along the ``L`` axis independently for each of the
    ``C`` channels, using "same" padding so the output keeps length ``L``.

    Args:
        embed_dim: number of channels ``C``; must equal the input's last dim.
        kernel_size: length of the convolution window (must be odd if causal).
        causal: zero every kernel tap after the center so that output
            position ``t`` only sees inputs at positions ``<= t``.
        positive: take the absolute value of the (masked) kernel, so all
            taps are non-negative.
        blockwise: learn one kernel shared by every channel.
        bias: add a learned per-channel bias.

    Shape:
        input: (*, L, C)
        output: (*, L, C)
    """

    attn_mask: torch.Tensor

    def __init__(
        self,
        embed_dim: int,
        kernel_size: int,
        causal: bool = False,
        positive: bool = False,
        blockwise: bool = False,
        bias: bool = True,
    ):
        # With an even kernel there is no single center tap to split past/future.
        assert not causal or kernel_size % 2 == 1, "causal conv requires odd kernel"

        super().__init__()
        self.embed_dim = embed_dim
        self.kernel_size = kernel_size
        self.causal = causal
        self.positive = positive
        self.blockwise = blockwise

        # A blockwise kernel has a single output row that is broadcast later.
        n_kernels = 1 if blockwise else embed_dim
        self.weight = nn.Parameter(torch.empty(n_kernels, 1, kernel_size))

        if not bias:
            self.register_parameter("bias", None)
        else:
            self.bias = nn.Parameter(torch.empty(embed_dim))

        tap_mask = torch.ones(kernel_size)
        if causal:
            # conv1d is a cross-correlation: with "same" padding, taps beyond
            # index kernel_size // 2 read future positions, so drop them.
            future_taps = torch.arange(kernel_size) > kernel_size // 2
            tap_mask.masked_fill_(future_taps, 0.0)
        self.register_buffer("attn_mask", tap_mask)

        self.init_weights()

    def init_weights(self):
        """Truncated-normal kernel (std 0.02) and zero bias."""
        nn.init.trunc_normal_(self.weight, std=0.02)
        if self.bias is not None:
            nn.init.zeros_(self.bias)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        *batch_shape, length, channels = input.shape
        assert channels == self.embed_dim

        kernel = self.weight * self.attn_mask
        if self.positive:
            kernel = kernel.abs()
        if self.blockwise:
            # Reuse the single shared kernel for every channel group.
            kernel = kernel.expand(self.embed_dim, 1, self.kernel_size)

        # conv1d expects (batch, channels, length): fold leading dims into batch.
        channels_first = input.reshape(-1, length, channels).transpose(1, 2)
        convolved = F.conv1d(
            channels_first, kernel, self.bias, padding="same", groups=self.embed_dim
        )
        return convolved.transpose(1, 2).reshape(*batch_shape, length, channels)

    def extra_repr(self):
        return (
            f"{self.embed_dim}, kernel_size={self.kernel_size}, "
            f"causal={self.causal}, positive={self.positive}, "
            f"blockwise={self.blockwise}, bias={self.bias is not None}"
        )
|
|
|
|
|
class LinearPoolLatent(nn.Module):
    """Pool a fixed-size set of features with a learned weighted sum.

    Each of the ``feat_size`` positions along dim -2 has its own learned
    per-channel weight; the weighted positions are summed away.

    Args:
        embed_dim: channel width ``C``.
        feat_size: number of pooled positions ``L``.
        positive: use the absolute value of the weights.

    Shape:
        input: (*, L, C)
        output: (*, C)
    """

    def __init__(self, embed_dim: int, feat_size: int, positive: bool = False):
        super().__init__()
        self.embed_dim = embed_dim
        self.feat_size = feat_size
        self.positive = positive

        self.weight = nn.Parameter(torch.empty(feat_size, embed_dim))
        self.init_weights()

    def init_weights(self):
        """Truncated-normal pooling weights (std 0.02)."""
        nn.init.trunc_normal_(self.weight, std=0.02)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        coeffs = self.weight.abs() if self.positive else self.weight
        # (L, C) weights broadcast over the leading dims of (*, L, C).
        weighted = input * coeffs
        return weighted.sum(dim=-2)

    def extra_repr(self):
        return f"{self.embed_dim}, feat_size={self.feat_size}, positive={self.positive}"
|
|
|
|
|
class AttentionPoolLatent(nn.Module):
    """Pool a set of features with a single learned attention query.

    A learned query vector attends (multi-head) over the ``L`` inputs; the
    attended values are merged back to ``C`` channels and projected.

    Args:
        embed_dim: channel width ``C``; must be divisible by ``num_heads``.
        num_heads: number of attention heads.

    Shape:
        input: (*, L, C)
        output: (*, C)
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int = 8,
    ):
        super().__init__()
        assert embed_dim % num_heads == 0
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.query = nn.Parameter(torch.zeros(embed_dim))
        # One projection produces keys (first C channels) and values (last C).
        self.kv = nn.Linear(embed_dim, 2 * embed_dim)
        self.proj = nn.Linear(embed_dim, embed_dim)
        self.init_weights()

    def init_weights(self):
        """Truncated-normal query (std 0.02); kv/proj keep their own init."""
        nn.init.trunc_normal_(self.query, std=0.02)

    def forward(self, x: torch.Tensor):
        *batch_shape, seq_len, dim = x.shape
        heads = self.num_heads
        head_dim = dim // heads

        flat = x.reshape(-1, seq_len, dim)
        n = len(flat)

        # Query: (n, heads, 1, head_dim), the same learned vector for every row.
        query = self.query.expand(n, 1, dim)
        query = query.reshape(n, 1, heads, head_dim).transpose(1, 2)

        # Keys/values: (n, heads, seq_len, head_dim) each.
        packed = self.kv(flat).reshape(n, seq_len, 2, heads, head_dim)
        keys, values = packed.permute(2, 0, 3, 1, 4).unbind(0)

        attended = F.scaled_dot_product_attention(query, keys, values)
        # Merge heads back to `dim` channels in the same (heads, head_dim) order.
        pooled = self.proj(attended.reshape(n, dim))
        return pooled.reshape(*batch_shape, dim)
|
|
|
class ConvLinear(nn.Module):
    """Depthwise temporal convolution followed by a linear projection.

    Args:
        in_features: channel width of the input (and of the depthwise conv).
        out_features: output width of the linear layer.
        kernel_size: depthwise conv kernel size.
        causal: use a causal conv (``DepthConv1d`` requires an odd kernel).
        positive: constrain the conv kernel to be non-negative.
        blockwise: share a single conv kernel across all channels.

    Shape:
        input: (*, L, in_features)
        output: (*, L, out_features)
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        kernel_size: int = 11,
        causal: bool = False,
        positive: bool = False,
        blockwise: bool = False,
    ):
        super().__init__()
        # DepthConv1d is depthwise by construction (it always convolves with
        # groups=embed_dim) and accepts no ``groups`` argument; passing one
        # raised TypeError and made this class unconstructible.
        self.conv = DepthConv1d(
            in_features,
            kernel_size=kernel_size,
            causal=causal,
            positive=positive,
            blockwise=blockwise,
        )
        self.fc = nn.Linear(in_features, out_features)

    def forward(self, x: torch.Tensor):
        x = self.conv(x)
        x = self.fc(x)
        return x
|
|
|
|
|
class LinearConv(nn.Module):
    """Linear projection followed by a depthwise temporal convolution.

    Args:
        in_features: input channel width.
        out_features: width after the linear layer (and of the conv).
        kernel_size: depthwise conv kernel size.
        causal: use a causal conv.
        positive: constrain the conv kernel to be non-negative.
        blockwise: share a single conv kernel across all channels.

    Shape:
        input: (*, L, in_features)
        output: (*, L, out_features)
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        kernel_size: int = 11,
        causal: bool = False,
        positive: bool = False,
        blockwise: bool = False,
    ):
        super().__init__()
        # Project first; the conv then operates in the out_features space.
        self.fc = nn.Linear(in_features, out_features)
        conv_options = dict(
            kernel_size=kernel_size,
            causal=causal,
            positive=positive,
            blockwise=blockwise,
        )
        self.conv = DepthConv1d(out_features, **conv_options)

    def forward(self, x: torch.Tensor):
        return self.conv(self.fc(x))
|
|
|
class MultiSubjectConvLinearEncoder(nn.Module):
    """Embed one or more feature streams and decode them to per-subject targets.

    Each feature stream gets its own embedding module (``_make_feat_embed``).
    The embedded streams are pooled over their layer axis, optionally passed
    through ``hidden_model``, then decoded by a shared decoder and/or one
    decoder per subject; the two decoder outputs are summed.

    Args:
        num_subjects: number of per-subject outputs.
        feat_dims: one entry per feature stream. An int ``C`` describes a
            single-layer stream; a tuple ``(n, C)`` describes ``n`` layers of
            width ``C``.
        embed_dim: width of the shared embedding.
        target_dim: output width of every decoder.
        hidden_model: optional module applied to the pooled embedding.
        global_pool: ``"avg"`` (mean over each stream's layers, summed over
            streams), ``"linear"`` (``LinearPoolLatent``) or ``"attn"``
            (``AttentionPoolLatent``).
        encoder_kernel_size: embedding conv kernel size; ``<= 1`` uses a plain
            ``nn.Linear``.
        decoder_kernel_size: if ``> 1`` the per-subject decoders are
            ``ConvLinear`` with this kernel size, otherwise ``nn.Linear``.
        encoder_causal, encoder_positive, encoder_blockwise: forwarded to the
            embedding conv.
        pool_num_heads: number of heads when ``global_pool="attn"``.
        with_shared_decoder, with_subject_decoders: which decoders to build;
            at least one must be enabled.

    Raises:
        NotImplementedError: for an unknown ``global_pool``.
    """

    weight: torch.Tensor

    def __init__(
        self,
        num_subjects: int = 4,
        feat_dims: tuple[int | tuple[int, int], ...] = (2048,),
        embed_dim: int = 256,
        target_dim: int = 1000,
        hidden_model: nn.Module | None = None,
        global_pool: Literal["avg", "linear", "attn"] = "avg",
        encoder_kernel_size: int = 33,
        decoder_kernel_size: int = 0,
        encoder_causal: bool = True,
        encoder_positive: bool = False,
        encoder_blockwise: bool = False,
        pool_num_heads: int = 4,
        with_shared_decoder: bool = True,
        with_subject_decoders: bool = True,
    ):
        assert with_shared_decoder or with_subject_decoders

        super().__init__()
        self.num_subjects = num_subjects
        self.global_pool = global_pool

        # Normalize every entry to (num_layers, width).
        feat_dims = [(1, dim) if isinstance(dim, int) else dim for dim in feat_dims]
        # Total number of layers across all streams = positions to pool over.
        total_feat_size = sum(dim[0] for dim in feat_dims)

        self.feat_embeds = nn.ModuleList(
            [
                _make_feat_embed(
                    dim[1],
                    embed_dim,
                    kernel_size=encoder_kernel_size,
                    causal=encoder_causal,
                    positive=encoder_positive,
                    blockwise=encoder_blockwise,
                )
                for dim in feat_dims
            ]
        )

        if global_pool == "avg":
            self.register_module("feat_pool", None)
        elif global_pool == "linear":
            self.feat_pool = LinearPoolLatent(embed_dim, total_feat_size)
        elif global_pool == "attn":
            # AttentionPoolLatent's signature is (embed_dim, num_heads) and it
            # reads L from its input, so total_feat_size must not be passed
            # (positionally it collided with num_heads -> TypeError).
            self.feat_pool = AttentionPoolLatent(embed_dim, num_heads=pool_num_heads)
        else:
            raise NotImplementedError(f"global_pool {global_pool} not implemented.")

        if hidden_model is not None:
            self.hidden_model = hidden_model
        else:
            self.register_module("hidden_model", None)

        if with_shared_decoder:
            self.shared_decoder = nn.Linear(embed_dim, target_dim)
        else:
            self.register_module("shared_decoder", None)

        if decoder_kernel_size > 1:
            decoder_linear = partial(ConvLinear, kernel_size=decoder_kernel_size)
        else:
            decoder_linear = nn.Linear

        if with_subject_decoders:
            self.subject_decoders = nn.ModuleList(
                [decoder_linear(embed_dim, target_dim) for _ in range(num_subjects)]
            )
        else:
            self.register_module("subject_decoders", None)

        # NOTE: apply() visits every submodule, including a user-supplied
        # hidden_model, so its Linear/Conv1d/norm layers are re-initialized too.
        self.apply(_init_weights)

    def forward(self, features: list[torch.Tensor]) -> torch.Tensor:
        """Encode ``features`` and decode them for every subject.

        Args:
            features: one tensor per ``feat_dims`` entry, in the same order.
                3-D tensors get a singleton layer axis inserted at dim 1;
                higher-rank tensors have dims 1 and 2 swapped. Presumably
                (N, T, C) and (N, T, n, C) respectively — TODO confirm with
                callers.

        Returns:
            A 4-D tensor with the subject axis at dim 1, i.e.
            (N, num_subjects, ..., target_dim) — (N, num_subjects, T,
            target_dim) under the shapes assumed above.

        Raises:
            ValueError: if the number of tensors differs from ``feat_dims``.
        """
        embed_features: list[torch.Tensor] = []
        # strict=True: a count mismatch would otherwise be silently truncated.
        for feat, feat_embed in zip(features, self.feat_embeds, strict=True):
            # Put each stream's layer axis at dim 1.
            feat = feat[:, None] if feat.ndim == 3 else feat.transpose(1, 2)
            embed_features.append(feat_embed(feat))

        if self.global_pool == "avg":
            # Mean over each stream's layer axis, then sum across streams.
            embed = sum(feat.mean(dim=1) for feat in embed_features)
        else:
            # Concatenate all layers of all streams on dim 1, move that axis to
            # dim 2 (the pooled dim -2 for 4-D embeddings) and pool it away.
            # These steps must not run for "avg": feat_pool is None there.
            embed = torch.cat(embed_features, dim=1)
            embed = embed.transpose(1, 2)
            embed = self.feat_pool(embed)

        if self.hidden_model is not None:
            embed = self.hidden_model(embed)

        if self.shared_decoder is not None:
            shared_output = self.shared_decoder(embed)
            # Broadcast the shared prediction to every subject.
            shared_output = shared_output[:, None].expand(-1, self.num_subjects, -1, -1)
        else:
            shared_output = 0.0

        if self.subject_decoders is not None:
            subject_output = torch.stack(
                [decoder(embed) for decoder in self.subject_decoders],
                dim=1,
            )
        else:
            subject_output = 0.0

        output = subject_output + shared_output
        return output
|
|
|
|
|
def _make_feat_embed(
    feat_dim: int = 2048,
    embed_dim: int = 256,
    kernel_size: int = 33,
    causal: bool = True,
    positive: bool = False,
    blockwise: bool = False,
) -> nn.Module:
    """Build the embedding module for one feature stream.

    Returns a plain ``nn.Linear(feat_dim, embed_dim)`` when ``kernel_size``
    is at most 1; otherwise a ``LinearConv`` that projects and then applies
    a depthwise conv with the given ``kernel_size``/``causal``/``positive``/
    ``blockwise`` options.
    """
    if kernel_size <= 1:
        # No temporal mixing requested: per-position projection only.
        return nn.Linear(feat_dim, embed_dim)
    return LinearConv(
        feat_dim,
        embed_dim,
        kernel_size=kernel_size,
        causal=causal,
        positive=positive,
        blockwise=blockwise,
    )
|
|
|
|
|
def _init_weights(m: nn.Module) -> None:
    """Initialize a single module in place (intended for ``Module.apply``).

    Conv1d / Linear / DepthConv1d: truncated-normal weight (std 0.02), zero
    bias. LayerNorm / RMSNorm with ``elementwise_affine``: unit weight and
    zero bias when a bias exists. Any other module is left untouched.
    """
    if isinstance(m, (nn.Conv1d, nn.Linear, DepthConv1d)):
        nn.init.trunc_normal_(m.weight, std=0.02)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
        return

    is_affine_norm = isinstance(m, (nn.LayerNorm, nn.RMSNorm)) and m.elementwise_affine
    if not is_affine_norm:
        return
    nn.init.constant_(m.weight, 1.0)
    # Not every norm layer exposes a bias attribute, and it may be None.
    if getattr(m, "bias", None) is not None:
        nn.init.constant_(m.bias, 0)
|
|
|