Spaces:
Build error
Build error
| # Copyright (c) Ye Liu. Licensed under the BSD 3-Clause License. | |
| import torch | |
| import torch.nn as nn | |
| from nncore.nn import MODELS, build_model | |
class R2Block(nn.Module):
    """Fuse video and query features taken from the last ``k`` input entries.

    The ``k`` trailing entries of the inputs are processed from the last one
    to the first. At every step:

    1. Video tokens and query tokens are projected to ``dims`` channels.
    2. Each query token attends over the ``P`` tokens of every frame.
    3. The attended features (plus a query residual) are max-pooled over
       query tokens.
    4. The pooled result is added, scaled by a learnable ``tanh`` factor, to
       token 0 of the frame.
    5. From the second step on, the result is blended with the previous
       step's output through a learnable sigmoid gate.
    6. The result is passed through the ``tem`` module.

    Args:
        dims (int): Hidden width of all projections.
        in_dims (sequence of int): ``in_dims[0]`` is the input width of the
            video features, ``in_dims[1]`` that of the query features.
        k (int, optional): Number of trailing input entries to use.
            Default: ``4``.
        dropout (float, optional): Dropout probability inside the two input
            projections. Default: ``0.5``.
        use_tef (bool, optional): Whether to append two normalized temporal
            features (``i / T`` and ``(i + 1) / T`` for frame ``i``) to every
            video token before projection. Default: ``True``.
        pos_cfg (optional): Config given to ``build_model`` (with
            ``dims=dims``). The resulting module's output is passed to
            ``tem`` as ``q_pe``.
        tem_cfg (optional): Config given to ``build_model`` (with
            ``dims=dims``). The resulting module is applied at the end of
            every step.
    """

    def __init__(self,
                 dims,
                 in_dims,
                 k=4,
                 dropout=0.5,
                 use_tef=True,
                 pos_cfg=None,
                 tem_cfg=None):
        super(R2Block, self).__init__()

        # Two extra channels for the temporal features when enabled.
        video_in = in_dims[0] + 2 if use_tef else in_dims[0]

        # Submodule names and Sequential layouts define the state_dict keys;
        # keep them stable so previously saved weights still load.
        self.video_map = self._build_mlp(video_in, dims, dropout)
        self.query_map = self._build_mlp(in_dims[1], dims, dropout)

        # One gate per blend step. The first processed entry has no previous
        # output to blend with, hence k - 1. sigmoid(0) = 0.5 at init.
        if k > 1:
            self.gate = nn.Parameter(torch.zeros([k - 1]))

        self.v_map = nn.Linear(dims, dims)
        self.q_map = nn.Linear(dims, dims)
        # tanh(0) = 0, so the attention branch contributes nothing at init.
        self.scale = nn.Parameter(torch.zeros([k]))

        self.pos = build_model(pos_cfg, dims=dims)
        self.tem = build_model(tem_cfg, dims=dims)

        self.dims = dims
        self.in_dims = in_dims
        self.k = k
        self.dropout = dropout
        self.use_tef = use_tef

    @staticmethod
    def _build_mlp(in_dim, dims, dropout):
        """Return the input projection (LN-Dropout-Linear-ReLU-LN-Dropout-Linear)."""
        # yapf:disable
        return nn.Sequential(
            nn.LayerNorm(in_dim),
            nn.Dropout(dropout),
            nn.Linear(in_dim, dims),
            nn.ReLU(inplace=True),
            nn.LayerNorm(dims),
            nn.Dropout(dropout),
            nn.Linear(dims, dims))
        # yapf:enable

    def forward(self, video_emb, query_emb, video_msk, query_msk):
        """Run the k-step video/query fusion.

        Args:
            video_emb (Tensor): 5-D tensor ``(N, B, T, P, C_v)``. Only the
                last ``k`` entries along dim 0 are used (presumably stacked
                per-layer features -- confirm against the caller).
            query_emb (Tensor): Indexed along dim 0 like ``video_emb``. Each
                entry is ``(B, L, C_q)``.
            video_msk: Forwarded to ``tem`` as ``q_mask``.
            query_msk: Forwarded to ``tem`` as ``k_mask``.

        Returns:
            tuple: ``(last, q_emb, coll_v, coll_q)``.

            - ``last``: output of ``tem`` at the final step.
            - ``q_emb``: projected query of the final step, ``(B, L, dims)``.
            - ``coll_v``: projected video token 0 of every step,
              ``(B, T, dims)`` each.
            - ``coll_q``: projected query of every step.

            Both lists are ordered from the last input entry to the first.

        NOTE(review): ``query_msk`` is not applied to the attention softmax
        or to the max-pool over query tokens below, so padded query tokens
        take part in both. Confirm this is intended.
        """
        video_emb = video_emb[-self.k:]
        query_emb = query_emb[-self.k:]

        _, b, t, p, _ = video_emb.size()

        if self.use_tef:
            # Build [i / t, (i + 1) / t] from an integer arange. A float step
            # (torch.arange(0, 1, 1 / t)) is subject to rounding and may yield
            # t + 1 values. Cast to the embedding dtype so half / bfloat16
            # inputs are not promoted to float32 by the concatenation.
            tef_s = torch.arange(t, device=video_emb.device) / t
            tef_e = tef_s + 1.0 / t
            tef = torch.stack((tef_s, tef_e), dim=1).to(video_emb.dtype)
            # (T, 2) -> (K, B, T, P, 2) as a broadcast view; cat copies it.
            tef = tef[None, None, :, None, :].expand(self.k, b, t, p, 2)
            video_emb = torch.cat((video_emb, tef), dim=-1)

        coll_v, coll_q, last = [], [], None
        for i in range(self.k - 1, -1, -1):
            v_emb = self.video_map(video_emb[i])  # B * T * P * C
            q_emb = self.query_map(query_emb[i])  # B * L * C

            coll_v.append(v_emb[:, :, 0])
            coll_q.append(q_emb)

            # Flatten frames into the batch so every frame is attended to
            # independently by the full query.
            v_pool = v_emb.view(b * t, -1, self.dims)  # BT * P * C
            q_pool = q_emb.repeat_interleave(t, dim=0)  # BT * L * C

            v_pool_map = self.v_map(v_pool)  # BT * P * C
            q_pool_map = self.q_map(q_pool)  # BT * L * C

            att = torch.bmm(q_pool_map, v_pool_map.transpose(1, 2))
            att = (att / self.dims**0.5).softmax(-1)  # BT * L * P

            o_pool = torch.bmm(att, v_pool) + q_pool  # BT * L * C
            o_pool = o_pool.amax(dim=1, keepdim=True)  # BT * 1 * C

            # Add the pooled features to token 0 of each frame.
            v_emb = v_pool[:, 0, None] + o_pool * self.scale[i].tanh()
            v_emb = v_emb.view(b, t, self.dims)  # B * T * C

            # Blend with the previous step's output (none at the first step).
            if i < self.k - 1:
                gate = self.gate[i].sigmoid()
                v_emb = gate * v_emb + (1 - gate) * last

            v_pe = self.pos(v_emb)
            last = self.tem(
                v_emb, q_emb, q_pe=v_pe, q_mask=video_msk, k_mask=query_msk)

        return last, q_emb, coll_v, coll_q