| import math |
| from math import pi as PI |
|
|
| import numpy as np |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import torch.nn.parallel |
| import torch.utils.data |
| import torch_geometric.transforms as T |
| from torch.nn import ModuleList, Parameter |
| from torch_geometric.nn import HANConv, HEATConv, HGTConv, Linear |
| from torch_geometric.nn.conv import MessagePassing |
| from torch_geometric.nn.dense.linear import Linear |
| |
| from torch_geometric.nn.inits import glorot, zeros |
| from torch_geometric.utils import softmax |
| from torch_scatter import scatter |
|
|
| from util import get_angle, get_theta, triplets |
|
|
class Smodel(nn.Module):
    """Stack of SPNN interaction blocks fed by a learned geometric encoding.

    A geometric encoding is computed once per forward pass (from distances
    and an angle when ``edge_rep`` is truthy, otherwise from raw endpoint
    coordinates) and is then handed, unchanged, to every interaction block.

    Args:
        h_channel: Hidden width used by the geometric MLPs and SPNN blocks.
        input_featuresize: Stored, forwarded to SPNN as ``in_ch`` and used
            to size ``translinear``.
        localdepth: Number of Linear layers in each geometric MLP.
        num_interactions: Number of SPNN blocks applied in sequence.
        finaldepth: Forwarded to every SPNN block.
        share: Accepted for interface compatibility; not used by this class.
        batchnorm: BatchNorm1d layers are inserted only when this equals the
            string ``"True"``.
    """

    def __init__(self, h_channel=16, input_featuresize=32, localdepth=2,
                 num_interactions=3, finaldepth=3, share='0', batchnorm="True"):
        super(Smodel, self).__init__()
        self.h_channel = h_channel
        self.input_featuresize = input_featuresize
        self.localdepth = localdepth
        self.num_interactions = num_interactions
        self.finaldepth = finaldepth
        self.batchnorm = batchnorm
        self.activation = nn.ReLU()
        # One learnable scale per edge-type pairing; shared by all SPNN blocks.
        self.att = Parameter(torch.ones(4), requires_grad=True)

        # One input column each for distance_ij, distance_jk and theta_ijk.
        num_gaussians = (1, 1, 1)
        self.mlp_geo = self._build_geo_mlp(sum(num_gaussians))
        # Fallback encoder: takes two concatenated coordinate rows (4 columns).
        self.mlp_geo_backup = self._build_geo_mlp(4)

        # NOTE(review): translinear is constructed (and therefore part of the
        # state_dict) but is not referenced by any method of this class.
        self.translinear = Linear(input_featuresize + 1, self.h_channel)

        self.interactions = ModuleList()
        for _ in range(self.num_interactions):
            block = SPNN(
                in_ch=self.input_featuresize,
                hidden_channels=self.h_channel,
                activation=self.activation,
                finaldepth=self.finaldepth,
                batchnorm=self.batchnorm,
                num_input_geofeature=self.h_channel,
            )
            self.interactions.append(block)
        self.reset_parameters()

    def _build_geo_mlp(self, in_features):
        """Return a ModuleList of ``localdepth`` Linear(+BatchNorm)+activation stages."""
        layers = ModuleList()
        for depth in range(self.localdepth):
            fan_in = in_features if depth == 0 else self.h_channel
            layers.append(Linear(fan_in, self.h_channel))
            if self.batchnorm == "True":
                layers.append(nn.BatchNorm1d(self.h_channel))
            layers.append(self.activation)
        return layers

    def reset_parameters(self):
        """Re-initialise ``mlp_geo`` (Xavier weights, zero bias) and every SPNN block.

        ``mlp_geo_backup``, ``translinear`` and ``att`` are left as their
        constructors initialised them.
        """
        for lin in self.mlp_geo:
            if isinstance(lin, Linear):
                torch.nn.init.xavier_uniform_(lin.weight)
                lin.bias.data.fill_(0)
        for interaction in self.interactions:
            interaction.reset_parameters()

    def single_forward(self, input_feature, coords, edge_index, edge_index_2rd,
                       edx_jk, edx_ij, batch, num_edge_inside, edge_rep):
        """Run all interaction blocks and collect the output of each one.

        Args:
            input_feature: Initial node features passed to the first block.
            coords: Node coordinates, indexed by node id.
                # assumes shape (num_nodes, 2) -- TODO confirm
            edge_index: Only read when ``edge_rep`` is falsy, where
                ``edge_index[0]`` / ``edge_index[1]`` index ``coords``.
            edge_index_2rd: Triplet indices unpackable as ``i, j, k``.
            edx_jk, edx_ij: Per-triplet edge ids, forwarded to the SPNN blocks.
            batch: Unused here; kept for interface compatibility.
            num_edge_inside: Edge-id threshold forwarded to the SPNN blocks.
            edge_rep: Selects the distance/angle encoding when truthy.

        Returns:
            A list with one node-feature tensor per interaction block, in
            application order.
        """
        if edge_rep:
            i, j, k = edge_index_2rd
            distance_ij = (coords[j] - coords[i]).norm(p=2, dim=1)
            distance_jk = (coords[j] - coords[k]).norm(p=2, dim=1)
            theta_ijk = get_angle(coords[j] - coords[i], coords[k] - coords[j])
            geo_encoding = torch.cat(
                [distance_ij[:, None], distance_jk[:, None], theta_ijk[:, None]],
                dim=-1,
            )
            for lin in self.mlp_geo:
                geo_encoding = lin(geo_encoding)
        else:
            coords_j = coords[edge_index[0]]
            coords_i = coords[edge_index[1]]
            geo_encoding = torch.cat([coords_j, coords_i], dim=-1)
            for lin in self.mlp_geo_backup:
                geo_encoding = lin(geo_encoding)
            # NOTE(review): the original indentation of this statement was
            # lost; it is placed here on the assumption that the geometric
            # signal is blanked only when edge_rep is disabled -- confirm.
            geo_encoding = torch.zeros_like(geo_encoding)

        node_feature = input_feature
        node_feature_list = []
        for interaction in self.interactions:
            node_feature = interaction(
                node_feature, geo_encoding, edge_index_2rd,
                edx_jk, edx_ij, num_edge_inside, self.att,
            )
            node_feature_list.append(node_feature)
        return node_feature_list

    def forward(self, input_feature, coords, edge_index, edge_index_2rd,
                edx_jk, edx_ij, batch, num_edge_inside, edge_rep):
        """Delegate to :meth:`single_forward` and return its list of features."""
        return self.single_forward(
            input_feature, coords, edge_index, edge_index_2rd,
            edx_jk, edx_ij, batch, num_edge_inside, edge_rep,
        )
| |
class SPNN(torch.nn.Module):
    """Triplet-level message block with one MLP per edge-type pairing.

    Each triplet ``(i, j, k)`` is routed to one of four MLPs according to
    whether its ``ij`` and ``jk`` edge ids fall below ``num_edge_inside``;
    the MLP outputs are scaled by ``att`` and summed onto node ``i``.

    Args:
        in_ch: Accepted for interface compatibility; not used. The first
            Linear layer is sized for three node vectors of width
            ``hidden_channels`` plus the geometric encoding.
        hidden_channels: Width of every MLP layer and of the output.
        activation: Module appended after each Linear (and BatchNorm) stage.
        finaldepth: Each MLP has ``finaldepth + 1`` Linear layers.
        batchnorm: BatchNorm1d layers are inserted only when this equals the
            string ``"True"``.
        num_input_geofeature: Width of the geometric encoding.
    """

    def __init__(
        self,
        in_ch,
        hidden_channels,
        activation=torch.nn.ReLU(),
        finaldepth=3,
        batchnorm="True",
        num_input_geofeature=13
    ):
        super(SPNN, self).__init__()
        self.activation = activation
        self.finaldepth = finaldepth
        self.batchnorm = batchnorm
        self.num_input_geofeature = num_input_geofeature

        # Index 0..3 = (ij inside, jk inside), (inside, outside),
        # (outside, inside), (outside, outside); see the masks in forward().
        self.WMLP_list = ModuleList()
        for _ in range(4):
            WMLP = ModuleList()
            for depth in range(self.finaldepth + 1):
                if depth == 0:
                    WMLP.append(
                        Linear(hidden_channels * 3 + num_input_geofeature,
                               hidden_channels)
                    )
                else:
                    WMLP.append(Linear(hidden_channels, hidden_channels))
                if self.batchnorm == "True":
                    WMLP.append(nn.BatchNorm1d(hidden_channels))
                WMLP.append(self.activation)
            self.WMLP_list.append(WMLP)
        self.reset_parameters()

    def reset_parameters(self):
        """Xavier-initialise every Linear weight and zero every Linear bias."""
        for mlp in self.WMLP_list:
            for lin in mlp:
                if isinstance(lin, Linear):
                    torch.nn.init.xavier_uniform_(lin.weight)
                    lin.bias.data.fill_(0)

    def forward(self, node_feature, geo_encoding, edge_index_2rd, edx_jk,
                edx_ij, num_edge_inside, att):
        """Aggregate per-triplet messages onto the first node of each triplet.

        Args:
            node_feature: Node features indexed by node id, or ``None`` to
                use ``geo_encoding`` alone as the MLP input.
            geo_encoding: One row per triplet.
            edge_index_2rd: Triplet indices unpackable as ``i, j, k``.
            edx_jk, edx_ij: Per-triplet edge ids compared against
                ``num_edge_inside`` to pick the MLP.
            num_edge_inside: Threshold separating the two edge id ranges.
            att: Indexable with 0..3; scales the output of the matching MLP.

        Returns:
            Tensor of summed messages per target node. When ``node_feature``
            is given it has exactly ``node_feature.shape[0]`` rows.
        """
        i, j, k = edge_index_2rd
        if node_feature is None:
            x_i = geo_encoding
            num_nodes = None
        else:
            x_i = torch.cat(
                [node_feature[i], node_feature[j], node_feature[k], geo_encoding],
                dim=-1,
            )
            # Pin the output height so nodes that never appear as ``i`` still
            # get a (zero) row; otherwise scatter sizes the result from
            # i.max() + 1 and later j/k lookups can run out of range.
            num_nodes = node_feature.shape[0]

        ij_inside = edx_ij < num_edge_inside
        ij_outside = edx_ij >= num_edge_inside
        jk_inside = edx_jk < num_edge_inside
        jk_outside = edx_jk >= num_edge_inside
        masks = [
            ij_inside & jk_inside,
            ij_inside & jk_outside,
            ij_outside & jk_inside,
            ij_outside & jk_outside,
        ]

        out_width = self.WMLP_list[0][0].weight.shape[0]
        # new_zeros inherits dtype and device from the input, so a float64
        # model is not silently downcast to the default dtype.
        x_output = x_i.new_zeros((x_i.shape[0], out_width))
        for index, (mask, mlp) in enumerate(zip(masks, self.WMLP_list)):
            x = x_i[mask]
            for lin in mlp:
                x = lin(x)
            x = F.leaky_relu(x) * att[index]
            x_output[mask] += x

        return scatter(x_output, i, dim=0, dim_size=num_nodes, reduce='add')
|
|
class HGT(torch.nn.Module):
    """HGTConv stack over a single ``'vertices'`` node type with two relations.

    Args:
        hidden_channels: Width of the input projection and of every HGTConv.
        out_channels: Width of the final Linear output.
        num_heads: Passed to every HGTConv as its ``heads`` argument.
        num_layers: Number of HGTConv layers.
    """

    def __init__(self, hidden_channels, out_channels, num_heads, num_layers):
        super().__init__()

        # in_channels=-1 lets the PyG Linear infer its input width lazily.
        self.lin_dict = torch.nn.ModuleDict()
        for node_type in ["vertices"]:
            self.lin_dict[node_type] = Linear(-1, hidden_channels)

        metadata = (
            ['vertices'],
            [('vertices', 'inside', 'vertices'), ('vertices', 'apart', 'vertices')],
        )
        self.convs = torch.nn.ModuleList()
        for _ in range(num_layers):
            conv = HGTConv(hidden_channels, hidden_channels, metadata,
                           num_heads, group='sum')
            self.convs.append(conv)

        self.lin = Linear(hidden_channels, out_channels)

    def forward(self, x_dict, edge_index_dict):
        """Project, convolve and read out the ``'vertices'`` features.

        Args:
            x_dict: Mapping from node type to feature tensor. It is not
                modified; every key must exist in ``self.lin_dict``.
            edge_index_dict: Mapping from edge type to edge index, passed
                unchanged to every HGTConv.

        Returns:
            Output of the final Linear layer for the ``'vertices'`` nodes.
        """
        # Build a fresh dict rather than overwriting the caller's entries, so
        # the same input batch can be fed through the model more than once.
        x_dict = {
            node_type: self.lin_dict[node_type](x).relu_()
            for node_type, x in x_dict.items()
        }

        for conv in self.convs:
            x_dict = conv(x_dict, edge_index_dict)
        return self.lin(x_dict['vertices'])
class HAN(torch.nn.Module):
    """HANConv stack over a single ``'vertices'`` node type with two relations.

    Args:
        hidden_channels: Width of the input projection and of every HANConv.
        out_channels: Width of the final Linear output.
        num_heads: Passed to every HANConv as its ``heads`` argument.
        num_layers: Number of HANConv layers.
    """

    def __init__(self, hidden_channels, out_channels, num_heads, num_layers):
        super().__init__()

        # in_channels=-1 lets the PyG Linear infer its input width lazily.
        self.lin_dict = torch.nn.ModuleDict()
        for node_type in ["vertices"]:
            self.lin_dict[node_type] = Linear(-1, hidden_channels)

        metadata = (
            ['vertices'],
            [('vertices', 'inside', 'vertices'), ('vertices', 'apart', 'vertices')],
        )
        self.convs = torch.nn.ModuleList()
        for _ in range(num_layers):
            conv = HANConv(hidden_channels, hidden_channels, metadata, num_heads)
            self.convs.append(conv)

        self.lin = Linear(hidden_channels, out_channels)

    def forward(self, x_dict, edge_index_dict):
        """Project, convolve and read out the ``'vertices'`` features.

        Args:
            x_dict: Mapping from node type to feature tensor. It is not
                modified; every key must exist in ``self.lin_dict``.
            edge_index_dict: Mapping from edge type to edge index, passed
                unchanged to every HANConv.

        Returns:
            Output of the final Linear layer for the ``'vertices'`` nodes.
        """
        # Build a fresh dict rather than overwriting the caller's entries, so
        # the same input batch can be fed through the model more than once.
        x_dict = {
            node_type: self.lin_dict[node_type](x).relu_()
            for node_type, x in x_dict.items()
        }

        for conv in self.convs:
            x_dict = conv(x_dict, edge_index_dict)
        return self.lin(x_dict['vertices'])
|
|