| import random |
| import math |
| import torch |
| import torch.nn as nn |
|
|
|
|
class Pooler(nn.Module):
    """Average-pool a square grid of tokens to coarser grid(s), then apply an MLP.

    Args:
        dim_in: feature size ``D`` of the incoming tokens.
        dim_out: feature size produced by the MLP.
        pool_out_size: output grid side length spec(s). Either a single value
            or a list/tuple of values. Each value is an int (e.g. ``4``) or a
            string (e.g. ``"4"`` or ``"2+4"``, meaning several grids whose
            tokens are concatenated). On every ``forward`` call one entry is
            drawn with ``random.choice``.
    """

    def __init__(self, dim_in, dim_out, pool_out_size):
        super().__init__()
        if not isinstance(pool_out_size, (list, tuple)):
            pool_out_size = [pool_out_size]

        self.pool_out_size = pool_out_size
        print("pool_out_size: {}".format(self.pool_out_size))

        self.mlp = nn.Sequential(
            nn.Linear(dim_in, dim_out),
            nn.GELU(),
            nn.Linear(dim_out, dim_out),
        )

    def forward(self, x):
        """
        Args:
            x (torch.Tensor): image features of shape (b, T, F, v, D), with
                T == F == 1. Token 0 along ``v`` is dropped. The remaining
                ``v - 1`` tokens are reshaped (row-major) into an ``s x s``
                grid, with ``s = isqrt(v - 1)``.
        Returns:
            torch.Tensor of shape (b, 1, n, dim_out), where n is the sum of
            ``p * p`` over the pool sizes ``p`` parsed from the randomly
            chosen spec (largest grid first).
        """
        b, t, f, v, d = x.shape
        s = math.isqrt(v - 1)
        assert t == 1 and f == 1, "Pooler expects T == 1 and F == 1"
        grid = x[:, :, :, 1:, :].reshape(b, t, f, s, s, d)

        # str() lets plain ints be used as specs as well as "2+4"-style strings.
        spec = str(random.choice(self.pool_out_size))
        sizes = sorted((int(p) for p in spec.split('+')), reverse=True)

        outputs = []
        for p in sizes:
            k = s // p  # pooling window side; s must be divisible by p
            pooled = grid.reshape(b, t, f, p, k, p, k, d)
            # Move the two window axes (k, k) last and average over them:
            # (b, t, f, p, p, d, k, k) -> (b, t, f, p*p, d)
            pooled = (
                pooled.permute(0, 1, 2, 3, 5, 7, 4, 6)
                .reshape(b, t, f, p * p, d, -1)
                .mean(-1)
            )
            pooled = self.mlp(pooled)
            outputs.append(pooled.flatten(0, 2))  # (b, p*p, dim_out) since t == f == 1
        return torch.cat(outputs, dim=-2).unsqueeze(1)
|
|