| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from torch_scatter import scatter_sum |
| from torch_geometric.nn import radius_graph, knn_graph |
| from models.add_module.common import GaussianSmearing, MLP, batch_hybrid_edge_connection, NONLINEARITIES |
|
|
|
|
class EnBaseLayer(nn.Module):
    """One EGNN-style message-passing layer over a graph given by ``edge_index``.

    Edge messages are built from the features of both endpoints plus an
    encoding of the pairwise distance (and optional extra edge features).
    They are gated by a learned sigmoid weight and summed into the target
    node, then used to residually update ``h``. When ``update_x`` is set,
    coordinates are also shifted along the relative position vectors, each
    scaled by a learned scalar (in ``[-1, 1]``, from ``Tanh``) per edge.

    Args:
        hidden_dim: width of node features ``h`` and of the hidden MLPs.
        num_r_gaussian: number of Gaussian basis functions used to expand
            distances (range ``[r_min, r_max]``). If ``<= 1``, the raw squared
            distance (a single feature) is used instead.
        update_x: whether ``forward`` also updates coordinates.
        act_fn: activation name; passed to ``MLP`` and used as a key into
            ``NONLINEARITIES``.
        norm: forwarded to ``MLP``.
        edge_feat_dim: width of the optional ``edge_attr`` tensor given to
            ``forward``. The default 0 means no extra edge features.
    """

    def __init__(self, hidden_dim, num_r_gaussian, update_x=True, act_fn='silu', norm=False,
                 edge_feat_dim=0):
        super().__init__()
        self.r_min = 0.
        self.r_max = 10.
        self.hidden_dim = hidden_dim
        self.num_r_gaussian = num_r_gaussian
        self.edge_feat_dim = edge_feat_dim
        self.update_x = update_x
        self.act_fn = act_fn
        self.norm = norm
        if num_r_gaussian > 1:
            self.distance_expansion = GaussianSmearing(self.r_min, self.r_max, num_gaussians=num_r_gaussian)
        # forward() falls back to the raw squared distance (one feature) when
        # no Gaussian expansion is used, so size the edge-MLP input to match.
        dist_feat_dim = num_r_gaussian if num_r_gaussian > 1 else 1
        self.edge_mlp = MLP(2 * hidden_dim + dist_feat_dim + edge_feat_dim, hidden_dim, hidden_dim,
                            num_layer=2, norm=norm, act_fn=act_fn, act_last=True)
        # Per-edge gate in (0, 1) applied to each message before aggregation.
        self.edge_inf = nn.Sequential(nn.Linear(hidden_dim, 1), nn.Sigmoid())
        if self.update_x:
            x_mlp = [nn.Linear(hidden_dim, hidden_dim), NONLINEARITIES[act_fn]]
            layer = nn.Linear(hidden_dim, 1, bias=False)
            # Tiny init gain keeps the initial coordinate updates close to zero.
            torch.nn.init.xavier_uniform_(layer.weight, gain=0.001)
            x_mlp.append(layer)
            x_mlp.append(nn.Tanh())
            self.x_mlp = nn.Sequential(*x_mlp)

        self.node_mlp = MLP(2 * hidden_dim, hidden_dim, hidden_dim, num_layer=2, norm=norm, act_fn=act_fn)

    def forward(self, h, x, edge_index, edge_attr=None):
        """Run one round of message passing.

        Args:
            h: node features, indexed by node along dim 0.
            x: node coordinates, indexed by node along dim 0.
            edge_index: ``(src, dst)`` index pair. Messages flow from
                ``src`` into ``dst``, and aggregation is over ``dst``.
            edge_attr: optional extra per-edge features, concatenated after
                the distance features. Its width must equal ``edge_feat_dim``.

        Returns:
            Tuple ``(h, x)`` of updated node features and coordinates. ``x``
            is returned unchanged when ``update_x`` is False.
        """
        src, dst = edge_index
        hi, hj = h[dst], h[src]

        rel_x = x[dst] - x[src]
        d_sq = torch.sum(rel_x ** 2, -1, keepdim=True)
        # Epsilon avoids the infinite gradient of sqrt at zero distance.
        dist = torch.sqrt(d_sq + 1e-8)
        if self.num_r_gaussian > 1:
            d_feat = self.distance_expansion(dist)
        else:
            d_feat = d_sq

        if edge_attr is not None:
            edge_feat = torch.cat([d_feat, edge_attr], -1)
        else:
            edge_feat = d_feat

        mij = self.edge_mlp(torch.cat([hi, hj, edge_feat], -1))
        eij = self.edge_inf(mij)
        mi = scatter_sum(mij * eij, dst, dim=0, dim_size=h.shape[0])

        h = h + self.node_mlp(torch.cat([mi, h], -1))
        if self.update_x:
            # Dividing by (distance + 1) keeps each direction vector's norm
            # below 1. dim_size guarantees one row per node even when some
            # nodes receive no edges.
            delta_x = scatter_sum(rel_x / (dist + 1) * self.x_mlp(mij), dst, dim=0, dim_size=x.shape[0])
            x = x + delta_x

        return h, x
|
|
|
|
class EGNN(nn.Module):
    """Stack of ``EnBaseLayer`` blocks applied to a batch of equal-length point sets.

    Inputs with leading ``(B, L)`` dims are flattened to ``B * L`` nodes. A
    graph is rebuilt from the current coordinates before every layer, and the
    outputs are reshaped back to ``(B, L, -1)``.

    Args:
        num_layers: number of ``EnBaseLayer`` blocks.
        hidden_dim: node feature width.
        num_r_gaussian: Gaussian basis size passed to each layer.
        k: neighbours per node in ``'knn'`` mode.
        cutoff: radius used in ``'radius'`` mode.
        cutoff_mode: ``'knn'`` or ``'radius'``. ``'hybrid'`` is not implemented.
        update_x: whether layers update coordinates.
        act_fn: activation name passed to each layer.
        norm: forwarded to each layer.
    """

    def __init__(self, num_layers=2, hidden_dim=128, num_r_gaussian=20, k=32, cutoff=20.0, cutoff_mode='knn',
                 update_x=True, act_fn='silu', norm=False):
        super().__init__()
        self.num_layers = num_layers
        self.hidden_dim = hidden_dim
        self.num_r_gaussian = num_r_gaussian
        self.update_x = update_x
        self.act_fn = act_fn
        self.norm = norm
        self.k = k
        self.cutoff = cutoff
        self.cutoff_mode = cutoff_mode
        # NOTE(review): not used by forward(). Left in place because removing
        # a submodule can change this module's state_dict (checkpoint keys).
        self.distance_expansion = GaussianSmearing(stop=cutoff, num_gaussians=num_r_gaussian)
        self.net = self._build_network()

    def _build_network(self):
        """Create ``num_layers`` independent ``EnBaseLayer`` instances."""
        return nn.ModuleList([
            EnBaseLayer(self.hidden_dim, self.num_r_gaussian,
                        update_x=self.update_x, act_fn=self.act_fn, norm=self.norm)
            for _ in range(self.num_layers)
        ])

    def _connect_edge(self, x, batch, orgshape):
        """Build an ``edge_index`` from the current coordinates.

        Args:
            x: flattened node coordinates, one row per node.
            batch: graph id of each node. Edges never cross graphs.
            orgshape: ``(B, L)`` of the original input. Currently unused; kept
                for interface compatibility.

        Returns:
            ``edge_index`` built with ``flow='source_to_target'``.

        Raises:
            NotImplementedError: for ``'hybrid'``, which was never wired up
                here (the branch previously left ``edge_index`` unbound).
            ValueError: for any other unknown cutoff mode.
        """
        if self.cutoff_mode == 'knn':
            return knn_graph(x, k=self.k, batch=batch, flow='source_to_target')
        if self.cutoff_mode == 'radius':
            return radius_graph(x, r=self.cutoff, batch=batch, flow='source_to_target')
        if self.cutoff_mode == 'hybrid':
            raise NotImplementedError("cutoff_mode 'hybrid' is not implemented for EGNN")
        raise ValueError(f'Not supported cutoff mode: {self.cutoff_mode}')

    def forward(self, node_embed, trans_t):
        """Apply all layers to a batch of point sets.

        Args:
            node_embed: node features with leading ``(B, L)`` dims.
            trans_t: node coordinates with the same leading ``(B, L)`` dims.

        Returns:
            Tuple ``(x, h)``: updated coordinates and features, each reshaped
            to ``(B, L, -1)``.
        """
        B, L = node_embed.shape[:2]
        x = trans_t.reshape(B * L, -1)
        h = node_embed.reshape(B * L, -1)

        # Graph id per flattened node: L copies of 0, then L copies of 1, ...
        batch = torch.arange(B, device=node_embed.device).repeat_interleave(L)

        for layer in self.net:
            # Rebuild connectivity each layer, since the layer may have moved x.
            edge_index = self._connect_edge(x, batch, orgshape=(B, L))
            h, x = layer(h, x, edge_index)
        return x.reshape(B, L, -1), h.reshape(B, L, -1)
|
|