# flow-matching / src/medarc_architecture.py
# Uploaded by sabertoaster: "Upload folder using huggingface_hub" (commit 4edc9aa, verified).
import torch
import torch.nn as nn
import torch.nn.functional as F
from functools import partial
from typing import Literal
class DepthConv1d(nn.Module):
    """Depthwise 1D convolution along the sequence axis.

    Every channel is filtered by its own kernel, or by one kernel shared by all
    channels when ``blockwise`` is set. The input is zero-padded ("same"
    padding), so the output has the same length as the input.

    Args:
        embed_dim: number of channels ``C``.
        kernel_size: kernel length; must be odd when ``causal`` is set.
        causal: zero the kernel taps after the centre tap. With F.conv1d's
            cross-correlation and "same" padding, output step ``t`` then only
            sees input steps ``<= t``.
        positive: use the absolute value of the (masked) kernel.
        blockwise: learn a single kernel shared across all channels.
        bias: learn a per-channel additive bias.

    Shape:
        input: (*, L, C)
        output: (*, L, C)
    """

    attn_mask: torch.Tensor

    def __init__(
        self,
        embed_dim: int,
        kernel_size: int,
        causal: bool = False,
        positive: bool = False,
        blockwise: bool = False,
        bias: bool = True,
    ):
        assert not causal or kernel_size % 2 == 1, "causal conv requires odd kernel"
        super().__init__()
        self.embed_dim = embed_dim
        self.kernel_size = kernel_size
        self.causal = causal
        self.positive = positive
        self.blockwise = blockwise

        # One shared kernel in blockwise mode, otherwise one kernel per channel.
        n_kernels = 1 if blockwise else embed_dim
        self.weight = nn.Parameter(torch.empty((n_kernels, 1, kernel_size)))

        if not bias:
            self.register_parameter("bias", None)
        else:
            self.bias = nn.Parameter(torch.empty(embed_dim))

        # Taps 0..K//2 cover the current and past steps; zero the rest if causal.
        tap_mask = torch.ones(kernel_size)
        if causal:
            tap_mask[kernel_size // 2 + 1 :] = 0.0
        self.register_buffer("attn_mask", tap_mask)

        self.init_weights()

    def init_weights(self):
        """Truncated-normal kernel (std 0.02) and zero bias."""
        nn.init.trunc_normal_(self.weight, std=0.02)
        if self.bias is not None:
            nn.init.zeros_(self.bias)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        *batch_shape, seq_len, channels = input.shape
        assert channels == self.embed_dim

        # Fold all leading dims into one batch axis; conv1d wants (N, C, L).
        x = input.reshape(-1, seq_len, channels).transpose(1, 2)

        kernel = self.weight * self.attn_mask
        if self.positive:
            kernel = kernel.abs()
        if self.blockwise:
            # Share the single kernel across every channel group.
            kernel = kernel.expand(self.embed_dim, 1, self.kernel_size)

        y = F.conv1d(x, kernel, self.bias, padding="same", groups=self.embed_dim)
        return y.transpose(1, 2).reshape(batch_shape + [seq_len, channels])

    def extra_repr(self):
        parts = [
            str(self.embed_dim),
            f"kernel_size={self.kernel_size}",
            f"causal={self.causal}",
            f"positive={self.positive}",
            f"blockwise={self.blockwise}",
            f"bias={self.bias is not None}",
        ]
        return ", ".join(parts)
class LinearPoolLatent(nn.Module):
    """Learned weighted sum over a set of feature vectors.

    Computes ``output[..., c] = sum_l w[l, c] * input[..., l, c]``, where ``w``
    is the learned ``(feat_size, embed_dim)`` weight (``|w|`` when
    ``positive``). The set axis ``L`` must broadcast against ``feat_size``.

    Shape:
        input: (*, L, C)
        output: (*, C)
    """

    def __init__(self, embed_dim: int, feat_size: int, positive: bool = False):
        super().__init__()
        self.embed_dim = embed_dim
        self.feat_size = feat_size
        self.positive = positive
        self.weight = nn.Parameter(torch.empty(feat_size, embed_dim))
        self.init_weights()

    def init_weights(self):
        """Truncated-normal weights with std 0.02."""
        nn.init.trunc_normal_(self.weight, std=0.02)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        pool_weight = self.weight.abs() if self.positive else self.weight
        weighted = input * pool_weight  # broadcast (*, L, C) * (L, C)
        return weighted.sum(dim=-2)

    def extra_repr(self):
        return f"{self.embed_dim}, feat_size={self.feat_size}, positive={self.positive}"
class AttentionPoolLatent(nn.Module):
    """Multi-head attention pooling from a single learned query.

    A learned query vector attends over the ``L`` input features. Keys and
    values come from one linear layer, and the attended result is passed
    through an output projection.

    Shape:
        input: (*, L, C)
        output: (*, C)
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int = 8,
    ):
        super().__init__()
        assert embed_dim % num_heads == 0
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.query = nn.Parameter(torch.zeros(embed_dim))
        self.kv = nn.Linear(embed_dim, 2 * embed_dim)
        self.proj = nn.Linear(embed_dim, embed_dim)
        self.init_weights()

    def init_weights(self):
        """Truncated-normal (std 0.02) init of the learned query."""
        nn.init.trunc_normal_(self.query, std=0.02)

    def forward(self, x: torch.Tensor):
        *batch_shape, set_size, channels = x.shape
        heads = self.num_heads
        head_dim = channels // heads

        flat = x.reshape(-1, set_size, channels)
        batch = len(flat)

        # Same learned query for every item: (C,) -> (N, h, 1, C // h).
        query = (
            self.query.expand(batch, 1, channels)
            .reshape(batch, 1, heads, head_dim)
            .transpose(1, 2)
        )

        # (N, L, 2C) -> (2, N, h, L, C // h), then split into keys and values.
        keys_values = (
            self.kv(flat)
            .reshape(batch, set_size, 2, heads, head_dim)
            .permute(2, 0, 3, 1, 4)
        )
        keys, values = torch.unbind(keys_values, dim=0)

        # (N, h, 1, C // h); heads are re-merged in the same order they were split.
        attended = F.scaled_dot_product_attention(query, keys, values)
        pooled = self.proj(attended.reshape(batch, channels))
        return pooled.reshape(batch_shape + [channels])
class ConvLinear(nn.Module):
    """Depthwise temporal convolution followed by a linear projection.

    This is the mirror image of ``LinearConv``: the depthwise conv runs on the
    ``in_features`` channels, and the linear layer then maps them to
    ``out_features``.

    Args:
        in_features: input channel count (also the conv's channel count).
        out_features: output channel count.
        kernel_size: conv kernel length (must be odd when ``causal``).
        causal: use a causal conv mask.
        positive: constrain the conv kernel to be non-negative.
        blockwise: share a single conv kernel across all channels.

    Shape:
        input: (*, L, in_features)
        output: (*, L, out_features)
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        kernel_size: int = 11,
        causal: bool = False,
        positive: bool = False,
        blockwise: bool = False,
    ):
        super().__init__()
        # DepthConv1d is depthwise by construction (it passes
        # groups=embed_dim to F.conv1d itself) and has no ``groups`` argument;
        # passing one raised TypeError on every construction.
        self.conv = DepthConv1d(
            in_features,
            kernel_size=kernel_size,
            causal=causal,
            positive=positive,
            blockwise=blockwise,
        )
        self.fc = nn.Linear(in_features, out_features)

    def forward(self, x: torch.Tensor):
        # (*, L, in_features) -> (*, L, in_features) -> (*, L, out_features)
        x = self.conv(x)
        x = self.fc(x)
        return x
class LinearConv(nn.Module):
    """Linear projection followed by a depthwise temporal convolution.

    The linear layer maps ``in_features`` to ``out_features``; a ``DepthConv1d``
    over the ``out_features`` channels then filters along the sequence axis.

    Shape:
        input: (*, L, in_features)
        output: (*, L, out_features)
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        kernel_size: int = 11,
        causal: bool = False,
        positive: bool = False,
        blockwise: bool = False,
    ):
        super().__init__()
        self.fc = nn.Linear(in_features, out_features)
        conv_options = dict(
            kernel_size=kernel_size,
            causal=causal,
            positive=positive,
            blockwise=blockwise,
        )
        self.conv = DepthConv1d(out_features, **conv_options)

    def forward(self, x: torch.Tensor):
        # Project channels first, then convolve over time.
        return self.conv(self.fc(x))
class MultiSubjectConvLinearEncoder(nn.Module):
    """Encode feature streams into per-subject predictions.

    Processing steps:
    - Each feature stream is embedded with ``_make_feat_embed``: a linear
      projection, optionally followed by a depthwise temporal conv.
    - The per-stream embeddings are pooled into one ``embed_dim`` vector per
      time step.
    - The pooled embedding is optionally passed through ``hidden_model``.
    - It is then decoded by a shared decoder and/or one decoder per subject,
      and the two outputs are summed.

    Args:
        num_subjects: number of subjects (per-subject decoders / output slots).
        feat_dims: one entry per feature stream. An int ``D`` means the stream
            is (N, T, D); a pair ``(L, D)`` means it is (N, T, L, D).
        embed_dim: common embedding width.
        target_dim: output width per subject.
        hidden_model: optional module applied to the pooled embedding. Its
            output feeds decoders built for ``embed_dim`` inputs.
        global_pool: how the embedded feature slots are combined:
            - "avg": mean over each stream's slots, then sum across streams.
            - "linear": ``LinearPoolLatent`` over all slots.
            - "attn": ``AttentionPoolLatent`` over all slots.
        encoder_kernel_size, encoder_causal, encoder_positive, encoder_blockwise:
            options for each stream's temporal conv (``kernel_size <= 1``
            disables the conv).
        decoder_kernel_size: if > 1, subject decoders are ``ConvLinear`` with
            this kernel; otherwise plain ``nn.Linear``.
        pool_num_heads: attention heads for the "attn" pool.
        with_shared_decoder, with_subject_decoders: which decoders to build;
            at least one must be enabled.
    """

    weight: torch.Tensor

    def __init__(
        self,
        num_subjects: int = 4,
        feat_dims: tuple[int | tuple[int, int], ...] = (2048,),
        embed_dim: int = 256,
        target_dim: int = 1000,
        hidden_model: nn.Module | None = None,
        global_pool: Literal["avg", "linear", "attn"] = "avg",
        encoder_kernel_size: int = 33,
        decoder_kernel_size: int = 0,
        encoder_causal: bool = True,
        encoder_positive: bool = False,
        encoder_blockwise: bool = False,
        pool_num_heads: int = 4,
        with_shared_decoder: bool = True,
        with_subject_decoders: bool = True,
    ):
        assert with_shared_decoder or with_subject_decoders
        super().__init__()
        self.num_subjects = num_subjects
        self.global_pool = global_pool

        # Normalize every entry to (num_feature_slots, feature_dim).
        feat_dims = [(1, dim) if isinstance(dim, int) else dim for dim in feat_dims]
        total_feat_size = sum(dim[0] for dim in feat_dims)

        self.feat_embeds = nn.ModuleList(
            [
                _make_feat_embed(
                    dim[1],
                    embed_dim,
                    kernel_size=encoder_kernel_size,
                    causal=encoder_causal,
                    positive=encoder_positive,
                    blockwise=encoder_blockwise,
                )
                for dim in feat_dims
            ]
        )

        if global_pool == "avg":
            self.register_module("feat_pool", None)
        elif global_pool == "linear":
            self.feat_pool = LinearPoolLatent(embed_dim, total_feat_size)
        elif global_pool == "attn":
            # AttentionPoolLatent(embed_dim, num_heads) takes no feat_size: it
            # attends over however many slots it receives. Passing
            # total_feat_size positionally bound num_heads twice (TypeError).
            self.feat_pool = AttentionPoolLatent(embed_dim, num_heads=pool_num_heads)
        else:
            raise NotImplementedError(f"global_pool {global_pool} not implemented.")

        if hidden_model is not None:
            self.hidden_model = hidden_model
        else:
            self.register_module("hidden_model", None)

        if with_shared_decoder:
            self.shared_decoder = nn.Linear(embed_dim, target_dim)
        else:
            self.register_module("shared_decoder", None)

        if decoder_kernel_size > 1:
            decoder_linear = partial(ConvLinear, kernel_size=decoder_kernel_size)
        else:
            decoder_linear = nn.Linear
        if with_subject_decoders:
            self.subject_decoders = nn.ModuleList(
                [decoder_linear(embed_dim, target_dim) for _ in range(num_subjects)]
            )
        else:
            self.register_module("subject_decoders", None)

        # NOTE(review): apply() recurses into every submodule, so this also
        # re-initializes Linear/Conv1d/norm layers inside a caller-supplied
        # hidden_model (e.g. pretrained weights get overwritten).
        self.apply(_init_weights)

    def forward(self, features: list[torch.Tensor]):
        """Map feature streams to per-subject predictions.

        Args:
            features: one tensor per ``feat_dims`` entry, each (N, T, D_i) or
                (N, T, L_i, D_i).

        Returns:
            (N, num_subjects, T, target_dim), assuming ``hidden_model`` (if
            any) keeps the (N, T, embed_dim) shape.

        Raises:
            ValueError: if the number of feature tensors differs from the
                number of configured feature streams.
        """
        if len(features) != len(self.feat_embeds):
            # zip() below would otherwise silently drop unmatched streams.
            raise ValueError(
                f"expected {len(self.feat_embeds)} feature tensors, "
                f"got {len(features)}"
            )

        embed_features: list[torch.Tensor] = []
        for feat, feat_embed in zip(features, self.feat_embeds):
            # View as (N, L_i, T, D_i) so the embedding's temporal conv runs
            # along T (the second-to-last axis).
            feat = feat[:, None] if feat.ndim == 3 else feat.transpose(1, 2)
            # Project to (N, L_i, T, embed_dim).
            embed_features.append(feat_embed(feat))

        if self.global_pool == "avg":
            # Mean over each stream's slots, then sum across streams -> (N, T, d).
            # NOTE(review): streams are summed, not averaged — confirm intended.
            embed = sum(feat.mean(dim=1) for feat in embed_features)
        else:
            embed = torch.cat(embed_features, dim=1)
            # (N, L, T, d) -> (N, T, L, d), then pool over L -> (N, T, d).
            embed = embed.transpose(1, 2)
            embed = self.feat_pool(embed)

        if self.hidden_model is not None:
            embed = self.hidden_model(embed)

        if self.shared_decoder is not None:
            shared_output = self.shared_decoder(embed)
            # Broadcast the shared prediction to every subject.
            shared_output = shared_output[:, None].expand(-1, self.num_subjects, -1, -1)
        else:
            shared_output = 0.0

        if self.subject_decoders is not None:
            subject_output = torch.stack(
                [decoder(embed) for decoder in self.subject_decoders],
                dim=1,
            )
        else:
            subject_output = 0.0

        output = subject_output + shared_output
        return output
def _make_feat_embed(
    feat_dim: int = 2048,
    embed_dim: int = 256,
    kernel_size: int = 33,
    causal: bool = True,
    positive: bool = False,
    blockwise: bool = False,
) -> nn.Module:
    """Build the embedding module for one feature stream.

    Returns a plain ``nn.Linear(feat_dim, embed_dim)`` when ``kernel_size <= 1``.
    Otherwise returns a ``LinearConv`` (linear projection, then depthwise
    temporal conv) configured with the remaining options.
    """
    if kernel_size <= 1:
        # No temporal context requested; the conv options are unused here.
        return nn.Linear(feat_dim, embed_dim)
    return LinearConv(
        feat_dim,
        embed_dim,
        kernel_size=kernel_size,
        causal=causal,
        positive=positive,
        blockwise=blockwise,
    )
def _init_weights(m: nn.Module) -> None:
    """Initialize a single module in place; intended for ``nn.Module.apply``.

    Rules by module type:
    - Conv1d / Linear / DepthConv1d: weight from a truncated normal
      (std 0.02); bias, if present, set to 0.
    - LayerNorm / RMSNorm with ``elementwise_affine``: weight set to 1; bias,
      if present, set to 0.
    - Any other module is left untouched.
    """
    # nn.RMSNorm only exists in newer PyTorch releases (2.4+). Referencing it
    # unconditionally raised AttributeError on older versions for every
    # module that is not a Conv1d/Linear/DepthConv1d.
    norm_types = (nn.LayerNorm,) + ((nn.RMSNorm,) if hasattr(nn, "RMSNorm") else ())

    if isinstance(m, (nn.Conv1d, nn.Linear, DepthConv1d)):
        nn.init.trunc_normal_(m.weight, std=0.02)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
    elif isinstance(m, norm_types) and m.elementwise_affine:
        nn.init.constant_(m.weight, 1.0)
        # RMSNorm has no bias attribute; LayerNorm(bias=False) has bias=None.
        if getattr(m, "bias", None) is not None:
            nn.init.constant_(m.bias, 0)