| import torch |
| import torch.nn as nn |
| from einops import rearrange |
| import torch.nn.functional as F |
|
|
|
|
class Conv(nn.Module):
    """Convolution wrapper with a per-frame 2D mode and a chunked 3D mode.

    In ``"2d"`` mode the module wraps ``nn.Conv2d``. A 5-D input laid out as
    ``(B, C, T, H, W)`` is handled by folding ``T`` into the batch axis,
    convolving every frame independently and unfolding again; any other
    input is passed straight to the convolution.

    In ``"3d"`` mode the module wraps ``nn.Conv3d`` (built with
    ``padding=0``) and pads manually in ``forward``: spatial axes are
    zero-padded symmetrically by ``padding``, while the temporal axis is
    padded only at the front. The temporal axis is processed in slices of
    ``slice_seq_len`` frames to bound the size of each convolution call.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        kernel_size: Kernel size forwarded to the underlying convolution.
            In 3d mode an int is expanded to a 3-tuple for the padding maths.
        stride: Integer stride. In 3d mode it is applied to the spatial axes,
            and to the temporal axis only when ``temporal_down`` is truthy.
        padding: Padding for ``nn.Conv2d`` in 2d mode; symmetric spatial
            zero padding applied in ``forward`` in 3d mode.
        cnn_type: ``"2d"`` or ``"3d"``.
        causal_offset: Added to the number of zero frames placed in front of
            the first temporal slice (3d mode only).
        temporal_down: Whether ``stride`` also applies to the temporal axis
            (3d mode only).

    Raises:
        ValueError: If ``cnn_type`` is neither ``"2d"`` nor ``"3d"``.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=0,
        cnn_type="2d",
        causal_offset=0,
        temporal_down=False,
    ):
        super().__init__()
        if cnn_type not in ("2d", "3d"):
            # Fail early: otherwise no `conv` is created and forward() would
            # silently return None.
            raise ValueError(
                f"unsupported cnn_type {cnn_type!r}; expected '2d' or '3d'"
            )
        self.cnn_type = cnn_type
        # Number of input frames convolved per call in 3d mode.
        self.slice_seq_len = 17

        if cnn_type == "2d":
            self.conv = nn.Conv2d(
                in_channels, out_channels, kernel_size,
                stride=stride, padding=padding,
            )
        else:
            temporal_stride = stride if temporal_down else 1
            stride = (temporal_stride, stride, stride)
            # Padding is applied by hand in forward(), so the conv gets none.
            self.conv = nn.Conv3d(
                in_channels, out_channels, kernel_size,
                stride=stride, padding=0,
            )
            if isinstance(kernel_size, int):
                kernel_size = (kernel_size, kernel_size, kernel_size)
            # (front temporal padding of the first slice, height, width)
            self.padding = (
                kernel_size[0] - 1 + causal_offset,
                padding,
                padding,
            )
        self.causal_offset = causal_offset
        self.stride = stride
        self.kernel_size = kernel_size

    def forward(self, x):
        """Apply the convolution.

        Args:
            x: In 2d mode, either a 5-D ``(B, C, T, H, W)`` tensor or any
                input accepted by ``nn.Conv2d``. In 3d mode, a tensor whose
                axis 2 is the temporal axis and whose last two axes are
                spatial.

        Returns:
            The convolved tensor. In 3d mode the per-slice outputs are
            concatenated along axis 2.
        """
        if self.cnn_type == "2d":
            if x.ndim != 5:
                return self.conv(x)
            num_frames = x.shape[2]
            x = rearrange(x, "B C T H W -> (B T) C H W")
            x = self.conv(x)
            return rearrange(x, "(B T) C H W -> B C T H W", T=num_frames)

        assert self.stride[0] == 1 or self.stride[0] == 2, "only temporal stride = 1 or 2 are supported"

        front_pad, pad_h, pad_w = self.padding
        # F.pad consumes pairs starting from the last axis: W, then H, then T.
        spatial_pad = (pad_w, pad_w, pad_h, pad_h)
        # Input frames preceding a slice that its first kernel window needs.
        context = self.kernel_size[0] - 1
        total_frames = x.shape[2]
        step = self.slice_seq_len + self.stride[0] - 1

        outputs = []
        for start in range(0, total_frames, step):
            end = min(start + self.slice_seq_len, total_frames)
            if start == 0:
                # First slice: zero frames in front, nothing at the end.
                chunk = F.pad(x[:, :, :end], spatial_pad + (front_pad, 0))
            else:
                # Later slices: take the preceding `context` real frames as
                # temporal padding and zero-pad only spatially. Same values
                # as zero-padding in time and adding the frames in place.
                chunk = F.pad(x[:, :, start - context:end], spatial_pad)
            outputs.append(self.conv(chunk))

        try:
            return torch.cat(outputs, dim=2)
        except RuntimeError:
            # CUDA out-of-memory errors are RuntimeError subclasses. On CPU
            # there is nothing to free, so surface the original error.
            device = x.device
            if device.type == "cpu":
                raise
            # Drop local references to device tensors before retrying.
            del x, chunk
            outputs = [out.cpu() for out in outputs]
            if device.type == "cuda":
                torch.cuda.empty_cache()
            # Concatenate in host memory, then move the result back.
            return torch.cat(outputs, dim=2).to(device=device)