Spaces:
Build error
Build error
| # Copyright (c) Ye Liu. Licensed under the BSD 3-Clause License. | |
| import math | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from nncore.nn import MODELS | |
class Permute(nn.Module):
    """Swap the last two dimensions of the input tensor.

    In this file it wraps ``nn.Conv1d`` layers so that channels-last
    features can be convolved along the temporal axis and then restored.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        # transpose is symmetric in its two dims, so (-2, -1) == (-1, -2).
        return torch.transpose(x, -2, -1)
class ConvPyramid(nn.Module):
    """Build a multi-scale temporal pyramid from a single feature sequence.

    One branch is created per entry ``s`` of ``strides``:

    * ``int(log2(s)) == 0``: a plain ReLU (same resolution as the input);
    * ``int(log2(s)) > 0``: that many stride-2 ``Conv1d`` down-sampling stages;
    * ``int(log2(s)) < 0``: that many stride-2 ``ConvTranspose1d``
      up-sampling stages.

    Each stage is ``Permute -> conv -> Permute -> LayerNorm -> ReLU``.

    Args:
        dims (int): Feature dimension of the input and of every level.
        strides (Sequence[float]): Temporal stride of each level. The number
            of stages is ``int(math.log2(s))``, so strides are presumably
            meant to be powers of two (e.g. 0.5, 1, 2, 4) — confirm with
            callers; other values are truncated.
    """

    def __init__(self, dims, strides):
        super(ConvPyramid, self).__init__()

        self.blocks = nn.ModuleList()
        for s in strides:
            p = int(math.log2(s))
            if p == 0:
                # Must NOT be in-place: this branch receives the caller's
                # ``x`` directly, and the same ``x`` is fed to every other
                # level afterwards. An in-place ReLU would rectify the input
                # of all later levels, mutate the caller's tensor, and fail
                # under autograd when ``x`` is a leaf that requires grad.
                layers = nn.ReLU()
            else:
                layers = nn.Sequential()
                conv_cls = nn.Conv1d if p > 0 else nn.ConvTranspose1d
                for _ in range(abs(p)):
                    layers.extend([
                        Permute(),
                        conv_cls(dims, dims, 2, stride=2),
                        Permute(),
                        nn.LayerNorm(dims),
                        # In-place is fine here: it only touches the
                        # LayerNorm output produced just above.
                        nn.ReLU(inplace=True)
                    ])
            self.blocks.append(layers)

        self.strides = strides

    def forward(self, x, mask, return_mask=False):
        """Run every pyramid branch on ``x``.

        Args:
            x (torch.Tensor): Channels-last features; dim 1 is the temporal
                axis and the last dim is ``dims`` (presumably ``(B, T, dims)``
                — confirm against callers). ``x`` is not modified.
            mask (torch.Tensor): Temporal mask aligned with dim 1 of ``x``
                (assumed ``(B, T)`` with 1 for valid steps).
            return_mask (bool): If True, also build a resampled mask for
                every produced level.

        Returns:
            tuple[list[torch.Tensor], list[torch.Tensor]]: The output of every
            level whose stride does not exceed ``x.size(1)``, and the matching
            masks (an empty list when ``return_mask`` is False). Down-sampled
            masks are returned as ``long``; the others keep ``mask``'s dtype.
        """
        pymid, pymid_msk = [], []
        for s, blk in zip(self.strides, self.blocks):
            # Skip levels whose stride exceeds the sequence length.
            if x.size(1) < s:
                continue

            pymid.append(blk(x))

            if return_mask:
                if s > 1:
                    # For a 0/1 mask, a pooled step is valid if any step in
                    # its window is valid.
                    msk = F.max_pool1d(mask.float(), s, stride=s).long()
                elif s < 1:
                    # Up-sampling: every step is repeated 1/s times.
                    msk = mask.repeat_interleave(int(1 / s), dim=1)
                else:
                    msk = mask
                pymid_msk.append(msk)

        return pymid, pymid_msk
class AdaPooling(nn.Module):
    """Attention-weighted pooling over the temporal axis.

    A bias-free linear layer scores each step. Steps whose mask is not 1 are
    excluded, and the features are summed with the softmax of the scores as
    weights.

    Args:
        dims (int): Feature dimension of the input.
    """

    def __init__(self, dims):
        super(AdaPooling, self).__init__()
        self.att = nn.Linear(dims, 1, bias=False)

    def forward(self, x, mask):
        """Pool ``x`` into a single temporal step.

        Args:
            x (torch.Tensor): Features of shape ``(B, T, dims)``.
            mask (torch.Tensor): ``(B, T)`` mask. Steps equal to 1 take part
                in the pooling, and all other steps are ignored.

        Returns:
            torch.Tensor: Pooled features of shape ``(B, 1, dims)`` in the
            dtype of the scores. A row with no valid step yields NaN, because
            it takes the softmax of all ``-inf`` values.
        """
        # masked_fill keeps the score dtype. Adding a float32 tensor built
        # by ``torch.where`` would promote fp16/bf16 scores to float32, and
        # the matmul below would then fail on mixed dtypes.
        a = self.att(x).masked_fill(mask.unsqueeze(2) != 1, float('-inf'))
        a = a.softmax(dim=1)
        x = torch.matmul(x.transpose(1, 2), a)  # (B, dims, 1)
        return x.squeeze(2).unsqueeze(1)
class ConvHead(nn.Module):
    """Two-layer 1-D convolutional head applied along the temporal axis.

    Features are channels-last on both input and output. ``Permute`` moves
    the feature dimension onto the channel axis for ``nn.Conv1d`` and back.

    Args:
        dims (int): Input and hidden feature dimension.
        out_dims (int): Output feature dimension.
        kernal_size (int): Convolution kernel size (the misspelled keyword is
            kept for caller compatibility). Padding is ``kernal_size // 2``,
            so the temporal length is preserved only for odd sizes.
    """

    def __init__(self, dims, out_dims, kernal_size=3):
        super().__init__()
        pad = kernal_size // 2
        # The convolutions are built in the same order as before, so random
        # initialization and Sequential indices (state_dict keys) are unchanged.
        hidden = nn.Conv1d(dims, dims, kernal_size, padding=pad)
        proj = nn.Conv1d(dims, out_dims, kernal_size, padding=pad)
        self.module = nn.Sequential(
            Permute(), hidden, nn.ReLU(inplace=True), proj, Permute())

    def forward(self, x):
        """Map channels-last ``(B, T, dims)`` features to ``(B, T, out_dims)``."""
        return self.module(x)