| |
| |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from torch_scatter import scatter_sum |
|
|
|
|
def count_parameters(model):
    """Return the total number of scalar entries in the trainable parameters of ``model``.

    Only parameters with ``requires_grad`` set are counted.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total
|
|
|
|
def sequential_and(*tensors):
    """Fold ``torch.logical_and`` over all given tensors, left to right.

    With a single argument that tensor is returned unchanged.
    """
    combined, remaining = tensors[0], tensors[1:]
    for term in remaining:
        combined = torch.logical_and(combined, term)
    return combined
|
|
|
|
def sequential_or(*tensors):
    """Fold ``torch.logical_or`` over all given tensors, left to right.

    With a single argument that tensor is returned unchanged.
    """
    combined, remaining = tensors[0], tensors[1:]
    for term in remaining:
        combined = torch.logical_or(combined, term)
    return combined
|
|
|
|
def graph_to_batch(tensor, batch_id, padding_value=0, mask_is_pad=True):
    '''
    Pack a flat, graph-concatenated tensor into a dense padded batch.

    :param tensor: [N, D1, D2, ...]
    :param batch_id: [N], non-negative integer graph id of each row of tensor.
        Rows are written into the batch in row-major order of the mask, so
        rows of the same graph must be contiguous and batch_id ascending.
    :param padding_value: fill value for padded slots (cast to tensor.dtype)
    :param mask_is_pad: 1 in the mask indicates padding if set to True
    :return: (batch, mask) where batch is [bs, max_n, D1, D2, ...] with the
        dtype and device of tensor, and mask is a bool tensor of shape [bs, max_n]
    '''
    # number of rows per graph (same result as scatter-summing ones over batch_id)
    lengths = torch.bincount(batch_id)
    bs = lengths.shape[0]
    # guard the empty case: max() of an empty tensor raises
    max_n = int(lengths.max()) if bs > 0 else 0

    # torch.full keeps tensor.dtype (ones * padding_value would type-promote,
    # e.g. bool -> int64) and avoids a second allocation
    batch = torch.full(
        (bs, max_n, *tensor.shape[1:]), padding_value,
        dtype=tensor.dtype, device=tensor.device
    )

    # slot j of graph b holds data iff j < lengths[b]
    slots = torch.arange(max_n, device=tensor.device)
    data_mask = slots[None, :] < lengths.to(tensor.device)[:, None]
    pad_mask = torch.logical_not(data_mask)

    if tensor.shape[0] > 0:
        batch[data_mask] = tensor
    mask = pad_mask if mask_is_pad else data_mask
    return batch, mask
|
|
|
|
def variadic_arange(size):
    """
    from https://torchdrug.ai/docs/_modules/torchdrug/layers/functional/functional.html#variadic_arange

    Build the concatenation of ``arange(s)`` for every ``s`` in ``size`` as one
    1-D tensor, i.e. a variadic variant of
    ``torch.arange(stop).expand(batch_size, -1)``.

    Parameters:
        size (LongTensor): size of intervals of shape :math:`(N,)`
    """
    total = size.sum()
    # first flat position of every interval
    interval_starts = torch.cumsum(size, dim=0) - size
    # subtract each interval's start from the flat positions that fall in it
    per_element_start = torch.repeat_interleave(interval_starts, size)
    return torch.arange(total, device=size.device) - per_element_start
|
|
|
|
def variadic_meshgrid(input1, size1, input2, size2):
    """
    from https://torchdrug.ai/docs/_modules/torchdrug/layers/functional/functional.html#variadic_meshgrid
    Compute the Cartesian product for two batches of sets with variadic sizes.

    Suppose there are :math:`N` sets in each input,
    and the sizes of all sets are summed to :math:`B_1` and :math:`B_2` respectively.
    For the i-th pair of sets, every element of the first set is paired with
    every element of the second one; the second index varies fastest.

    Parameters:
        input1 (Tensor): input of shape :math:`(B_1, ...)`
        size1 (LongTensor): size of :attr:`input1` of shape :math:`(N,)`
        input2 (Tensor): input of shape :math:`(B_2, ...)`
        size2 (LongTensor): size of :attr:`input2` of shape :math:`(N,)`

    Returns
        (Tensor, Tensor): the first and the second elements in the Cartesian product
    """
    pairs_per_set = size1 * size2
    # position of every pair inside its own grid
    pos_in_grid = variadic_arange(pairs_per_set)
    # width of the grid each pair belongs to (size of the second set)
    inner_width = torch.repeat_interleave(size2, pairs_per_set)
    # first flat index of each set in input1 / input2, expanded per pair
    start1 = torch.repeat_interleave(torch.cumsum(size1, dim=0) - size1, pairs_per_set)
    start2 = torch.repeat_interleave(torch.cumsum(size2, dim=0) - size2, pairs_per_set)
    first = torch.div(pos_in_grid, inner_width, rounding_mode="floor") + start1
    second = torch.remainder(pos_in_grid, inner_width) + start2
    return input1[first], input2[second]
|
|
|
|
@torch.no_grad()
def length_to_batch_id(S, lengths):
    '''
    Expand per-segment lengths into a per-element segment (batch) id.

    :param S: tensor whose first dimension enumerates all elements; only its
        shape, dtype and device are used (the result is created with zeros_like)
    :param lengths: [bs], number of elements of each segment (zeros allowed),
        presumably summing to S.shape[0]
    :return: tensor shaped like S; entry i is the id of the segment containing i
    '''
    batch_id = torch.zeros_like(S)
    # position where each segment after the first one starts
    boundaries = torch.cumsum(lengths, dim=0)[:-1]
    # a boundary equal to len(S) belongs to trailing empty segments: no element
    # lives there, so drop it instead of indexing out of range
    boundaries = boundaries[boundaries != batch_id.shape[0]]
    if boundaries.numel() > 0:
        # accumulate instead of assign: an empty segment duplicates a boundary
        # and must still advance the id by one
        batch_id.index_put_(
            (boundaries,), torch.ones_like(batch_id[boundaries]), accumulate=True
        )
    batch_id.cumsum_(dim=0)
    return batch_id
|
|
|
|
def scatter_sort(src: torch.Tensor, index: torch.Tensor, dim=0, descending=False, eps=1e-12):
    '''
    from https://github.com/rusty1s/pytorch_scatter/issues/48

    Sort src within the groups given by index along dim. The result is ordered
    by group id (ascending) and, inside each group, by value (ascending, or
    descending if requested).

    Implemented with two sorts: first by value, then a stable sort by group id,
    which keeps the value order inside each group. Both sorts are stable so ties
    are resolved by original position (reproducible).

    :param src: values to sort
    :param index: group id of each value, same shape as src
    :param dim: dimension along which to sort
    :param descending: sort values inside each group in descending order
    :param eps: unused, kept for backward compatibility of the signature
    :return: (sorted_src, perm) such that the original src gathered with perm
        along dim equals sorted_src
    '''
    # 1) global sort by value
    src, src_perm = torch.sort(src, dim=dim, descending=descending, stable=True)
    index = index.take_along_dim(src_perm, dim=dim)
    # 2) stable sort by group id preserves the value order inside each group
    index, index_perm = torch.sort(index, dim=dim, stable=True)
    src = src.take_along_dim(index_perm, dim=dim)
    # compose both permutations along the same dim that was sorted
    perm = src_perm.take_along_dim(index_perm, dim=dim)
    return src, perm
|
|
|
|
def scatter_topk(src: torch.Tensor, index: torch.Tensor, k: int, dim=0, largest=True):
    '''
    Select up to k entries of src from every group given by index.

    :param src: values to rank (the masking below slices the first dimension,
        so 1-D input with dim=0 is the expected use)
    :param index: group id of each value, same shape as src
    :param k: maximum number of entries kept per group; 0 yields empty results
    :param dim: dimension forwarded to scatter_sort
    :param largest: keep the largest values if True, otherwise the smallest
    :return: (values, indices): the kept values, ordered by group and then by
        rank, and their positions in the original src
    :raises ValueError: if k is negative
    '''
    if k < 0:
        raise ValueError(f'k must be non-negative, got {k}')
    src, perm = scatter_sort(src, index, dim, descending=largest)
    index = index[perm]
    mask = torch.ones_like(index, dtype=torch.bool)
    if k == 0:
        # index[:-0] would be empty and break the comparison below
        mask[:] = False
    else:
        # groups are contiguous after sorting: position i is among the first k
        # of its group iff the element k places earlier is in another group
        mask[k:] = index[k:] != index[:-k]
    # perm already holds the original positions of the sorted entries
    return src[mask], perm[mask]
|
|
|
|
@torch.no_grad()
def knn_edges(all_edges, k_neighbors, X=None, atom_mask=None, given_dist=None):
    '''
    Keep, for every source node, its (at most) k_neighbors closest candidate edges.

    :param all_edges: [2, E], (row, col)
    :param k_neighbors: maximum number of edges kept per row node
    :param X: [N, n_channel, 3], coordinates
    :param atom_mask: [N, n_channel], 1 for having atom
    :param given_dist: [E], given distance of edges
    :return: [2, E'], the selected subset of all_edges
    IMPORTANT: either given_dist should be given, or both X and atom_mask should be given
    '''
    assert (given_dist is not None) or (X is not None and atom_mask is not None), \
        'either given_dist should be given, or both X and atom_mask should be given'

    # needed by both branches: the ranking below groups edges by their row node
    row, col = all_edges

    if given_dist is None:
        atom_mask = atom_mask.bool()  # accept 0/1 integer masks as well
        # pairwise distances between all channels of both end nodes: [E, n_channel, n_channel]
        dist = torch.norm(X[row][:, :, None, :] - X[col][:, None, :, :], dim=-1)
        dist_mask = atom_mask[row][:, :, None] & atom_mask[col][:, None, :]
        # pairs involving a missing atom must never be the minimum
        dist = torch.where(dist_mask, dist, torch.full_like(dist, float('inf')))
        # edge distance = closest pair of existing atoms
        dist, _ = dist.view(dist.shape[0], -1).min(dim=-1)
    else:
        dist = given_dist

    _, indices = scatter_topk(dist, row, k=k_neighbors, largest=False)
    edges = torch.stack([all_edges[0][indices], all_edges[1][indices]], dim=0)
    return edges
|
|
|
|
class EdgeConstructor:
    '''
    Build the edge sets of a batch of graphs: kNN edges inside a segment,
    kNN edges across segments, edges touching global nodes, and edges between
    nodes with adjacent indices.

    A node is treated as global if its type in S equals cor_idx or col_idx.
    Intermediate tensors are cached on the instance while construct_edges runs
    and released afterwards.
    '''
    def __init__(self, cor_idx, col_idx, atom_pos_pad_idx, rec_seg_id) -> None:
        '''
        :param cor_idx: node type (value in S) marking a global node
        :param col_idx: node type (value in S) marking a global node
        :param atom_pos_pad_idx: padding value of atom_pos, forwarded to the kNN routine
        :param rec_seg_id: segment id whose nodes get no sequential edges
            (presumably the receptor segment -- confirm with callers)
        '''
        self.cor_idx, self.col_idx = cor_idx, col_idx
        self.atom_pos_pad_idx = atom_pos_pad_idx
        self.rec_seg_id = rec_seg_id

        # buffers shared by the _construct_* helpers during one construct_edges call
        self._reset_buffer()

    def _reset_buffer(self):
        '''Drop all cached per-call tensors.'''
        self.row = None
        self.col = None
        self.row_global = None
        self.col_global = None
        self.row_seg = None
        self.col_seg = None
        self.offsets = None
        self.max_n = None
        self.gni2lni = None
        self.not_global_edges = None

    def get_batch_edges(self, batch_id):
        '''
        Enumerate all ordered node pairs (without self-loops) that share a graph.

        :param batch_id: [N], graph id of each node; nodes of one graph are
            expected to be contiguous and ids ascending
        :return: (row, col), (offsets, max_n, gni2lni) where row/col are global
            node indices of the pairs, offsets is the first global index of each
            graph, max_n the largest graph size (0-dim tensor) and gni2lni maps
            global node index to the index local to its graph
        '''
        # number of nodes per graph (same result as scatter-summing ones over batch_id)
        lengths = torch.bincount(batch_id)
        N, max_n = batch_id.shape[0], torch.max(lengths)
        offsets = F.pad(torch.cumsum(lengths, dim=0)[:-1], pad=(1, 0), value=0)
        # global node index to local node index
        gni = torch.arange(N, device=batch_id.device)
        gni2lni = gni - offsets[batch_id]

        # same_bid[i, j] is True iff local slot j exists in the graph of node i;
        # a bool [N, max_n] matrix instead of a float one keeps memory low
        local_slots = torch.arange(int(max_n), device=batch_id.device)
        same_bid = local_slots[None, :] < lengths[batch_id][:, None]
        # remove self-loops
        same_bid[(gni, gni2lni)] = False
        row, col = torch.nonzero(same_bid).T
        # local column index back to global node index
        col = col + offsets[batch_id[row]]
        return (row, col), (offsets, max_n, gni2lni)

    def _prepare(self, S, batch_id, segment_ids) -> None:
        '''Compute and cache the candidate pairs and their per-edge attributes.'''
        (row, col), (offsets, max_n, gni2lni) = self.get_batch_edges(batch_id)

        # edges touching a global node
        is_global = sequential_or(S == self.cor_idx, S == self.col_idx)
        row_global, col_global = is_global[row], is_global[col]
        not_global_edges = torch.logical_not(torch.logical_or(row_global, col_global))

        # segment id of both ends
        row_seg, col_seg = segment_ids[row], segment_ids[col]

        self.row, self.col = row, col
        self.offsets, self.max_n, self.gni2lni = offsets, max_n, gni2lni
        self.row_global, self.col_global = row_global, col_global
        self.not_global_edges = not_global_edges
        self.row_seg, self.col_seg = row_seg, col_seg

    def _construct_inner_edges(self, X, batch_id, k_neighbors, atom_pos):
        '''kNN edges among non-global pairs whose ends lie in the same segment.'''
        row, col = self.row, self.col
        select_edges = torch.logical_and(self.row_seg == self.col_seg, self.not_global_edges)
        ctx_all_row, ctx_all_col = row[select_edges], col[select_edges]
        # NOTE(review): _knn_edges is not defined in the visible part of this
        # file (only knn_edges, with a different signature) -- confirm where it
        # comes from.
        inner_edges = _knn_edges(
            X, atom_pos, torch.stack([ctx_all_row, ctx_all_col]).T,
            self.atom_pos_pad_idx, k_neighbors,
            (self.offsets, batch_id, self.max_n, self.gni2lni))
        return inner_edges

    def _construct_outer_edges(self, X, batch_id, k_neighbors, atom_pos):
        '''kNN edges among non-global pairs whose ends lie in different segments.'''
        row, col = self.row, self.col
        select_edges = torch.logical_and(self.row_seg != self.col_seg, self.not_global_edges)
        inter_all_row, inter_all_col = row[select_edges], col[select_edges]
        # NOTE(review): see _construct_inner_edges regarding _knn_edges.
        outer_edges = _knn_edges(
            X, atom_pos, torch.stack([inter_all_row, inter_all_col]).T,
            self.atom_pos_pad_idx, k_neighbors,
            (self.offsets, batch_id, self.max_n, self.gni2lni))
        return outer_edges

    def _construct_global_edges(self):
        '''
        :return: (global_normal, global_global), each [2, E]: same-segment pairs
            with at least one global end, and pairs whose ends are both global
        '''
        row, col = self.row, self.col
        # same segment and at least one end is a global node
        select_edges = torch.logical_and(self.row_seg == self.col_seg, torch.logical_not(self.not_global_edges))
        global_normal = torch.stack([row[select_edges], col[select_edges]])
        # both ends are global nodes (regardless of segment)
        select_edges = torch.logical_and(self.row_global, self.col_global)
        global_global = torch.stack([row[select_edges], col[select_edges]])
        return global_normal, global_global

    def _construct_seq_edges(self):
        '''
        :return: [2, E] non-global pairs whose global indices differ by exactly
            one and whose row node is not in segment rec_seg_id
        '''
        row, col = self.row, self.col
        select_edges = sequential_and(
            torch.logical_or((row - col) == 1, (row - col) == -1),
            self.not_global_edges,
            self.row_seg != self.rec_seg_id
        )
        seq_adj = torch.stack([row[select_edges], col[select_edges]])
        return seq_adj

    @torch.no_grad()
    def construct_edges(self, X, S, batch_id, k_neighbors, atom_pos, segment_ids):
        '''
        Memory efficient with complexity of O(Nn) where n is the largest number of nodes in the batch

        :param X: node coordinates, forwarded to the kNN routine
        :param S: [N], node types, used to detect global nodes
        :param batch_id: [N], graph id of each node
        :param k_neighbors: number of neighbors for the kNN edges
        :param atom_pos: atom position types, forwarded to the kNN routine
        :param segment_ids: [N], segment id of each node
        :return: (ctx_edges, inter_edges); ctx_edges concatenates inner kNN,
            global-normal, global-global and sequential edges along dim 1,
            inter_edges holds the edges across segments
        '''
        try:
            self._prepare(S, batch_id, segment_ids)

            inner_edges = self._construct_inner_edges(X, batch_id, k_neighbors, atom_pos)
            global_normal, global_global = self._construct_global_edges()
            seq_edges = self._construct_seq_edges()
            ctx_edges = torch.cat([inner_edges, global_normal, global_global, seq_edges], dim=1)

            inter_edges = self._construct_outer_edges(X, batch_id, k_neighbors, atom_pos)
        finally:
            # release the (potentially large) cached tensors even on failure
            self._reset_buffer()
        return ctx_edges, inter_edges
|
|
|
|
class GMEdgeConstructor(EdgeConstructor):
    '''
    Edge constructor for graph matching (kNN internal edges and all bipartite edges)

    Nodes are split into two sides: segment rec_seg_id and everything else.
    '''
    def _construct_inner_edges(self, X, batch_id, k_neighbors, atom_pos):
        '''kNN edges among non-global pairs whose ends are on the same side.'''
        src_on_rec = self.row_seg == self.rec_seg_id
        dst_on_rec = self.col_seg == self.rec_seg_id
        keep = (src_on_rec == dst_on_rec) & self.not_global_edges
        candidates = torch.stack([self.row[keep], self.col[keep]]).T
        batch_info = (self.offsets, batch_id, self.max_n, self.gni2lni)
        return _knn_edges(
            X, atom_pos, candidates,
            self.atom_pos_pad_idx, k_neighbors,
            batch_info)

    def _construct_global_edges(self):
        '''
        :return: (global_normal, global_global), each [2, E]: same-segment pairs
            with at least one global end, and pairs whose ends are both global
        '''
        touches_global = ~self.not_global_edges
        in_same_segment = self.row_seg == self.col_seg
        normal_sel = in_same_segment & touches_global
        global_normal = torch.stack([self.row[normal_sel], self.col[normal_sel]])

        both_global = self.row_global & self.col_global
        global_global = torch.stack([self.row[both_global], self.col[both_global]])
        return global_normal, global_global

    def _construct_outer_edges(self, X, batch_id, k_neighbors, atom_pos):
        '''All non-global pairs whose ends are on different sides (no kNN filtering).'''
        src_on_rec = self.row_seg == self.rec_seg_id
        dst_on_rec = self.col_seg == self.rec_seg_id
        keep = (src_on_rec != dst_on_rec) & self.not_global_edges
        return torch.stack([self.row[keep], self.col[keep]])
|
|
|
|
class SinusoidalPositionEmbedding(nn.Module):
    """
    Sin-Cos Positional Embedding

    Position p is encoded as the interleaved sequence
    [sin(p * w_0), cos(p * w_0), sin(p * w_1), cos(p * w_1), ...]
    with w_i = 10000 ** (-2 * i / output_dim).
    """
    def __init__(self, output_dim):
        """
        :param output_dim: size of each embedding, must be even
        :raises ValueError: if output_dim is odd (sin/cos pairs could not fill
            the rows and the final reshape would misalign them)
        """
        super(SinusoidalPositionEmbedding, self).__init__()
        if output_dim % 2 != 0:
            raise ValueError(f'output_dim must be even, got {output_dim}')
        self.output_dim = output_dim

    def forward(self, position_ids):
        """
        :param position_ids: [N], positions (integer or floating point)
        :return: [N, output_dim] embeddings
        """
        device = position_ids.device
        position_ids = position_ids[None]  # [1, N]
        indices = torch.arange(self.output_dim // 2, device=device, dtype=torch.float)
        indices = torch.pow(10000.0, -2 * indices / self.output_dim)
        if not torch.is_floating_point(position_ids):
            # integer positions: match the dtype of the frequencies explicitly
            # rather than relying on mixed-dtype einsum support
            position_ids = position_ids.to(indices.dtype)
        # outer product of positions and frequencies: [1, N, output_dim // 2]
        embeddings = torch.einsum('bn,d->bnd', position_ids, indices)
        # interleave sin and cos along the last axis, then flatten per position
        embeddings = torch.stack([torch.sin(embeddings), torch.cos(embeddings)], dim=-1)
        embeddings = embeddings.reshape(-1, self.output_dim)
        return embeddings