Spaces:
Sleeping
Sleeping
| import torch | |
| from torch import Tensor | |
| import torch.nn.functional as F | |
| from torch.nn import Sequential, Linear, BatchNorm1d | |
| import torch_geometric | |
| from torch_geometric.nn import ( | |
| Set2Set, | |
| global_mean_pool, | |
| global_add_pool, | |
| global_max_pool, | |
| CGConv, | |
| ) | |
| from torch_scatter import scatter_mean, scatter_add, scatter_max, scatter | |
| # CGCNN | |
class CGCNN(torch.nn.Module):
    """Graph network built from ``CGConv`` layers with optional dense stacks.

    Pipeline: optional dense layers on the node features -> ``gc_count``
    ``CGConv`` layers (each optionally followed by ``BatchNorm1d`` and always
    followed by dropout) -> pooling plus dense layers, whose relative order
    is selected by ``pool_order``.

    Args:
        data: Dataset-like object. The constructor reads
            ``data.num_features``, ``data.num_edge_features`` and
            ``data[0].y``.
        dim1: Width of the pre-GNN dense layers; also the graph-convolution
            width whenever ``pre_fc_count > 0``.
        dim2: Width of the post-GNN dense layers.
        pre_fc_count: Number of dense layers applied before the convolutions.
            With 0 the convolutions operate directly on ``data.x``.
        gc_count: Number of ``CGConv`` layers; must be greater than 0.
        post_fc_count: Number of dense layers applied before the output layer.
        pool: ``"set2set"`` or the name of a function in
            ``torch_geometric.nn`` that is called as ``fn(out, data.batch)``.
        pool_order: ``"early"`` pools before the post-GNN dense layers,
            ``"late"`` pools after the output layer.
        batch_norm: Batch norm after each convolution is enabled by the
            string ``"True"`` or the boolean ``True``; anything else
            disables it.
        batch_track_stats: Running statistics are disabled by the string
            ``"False"`` or the boolean ``False``; anything else enables them.
        act: Name of the activation function looked up on
            ``torch.nn.functional``.
        dropout_rate: Dropout probability applied after every convolution.
        **kwargs: Accepted and ignored.
    """

    def __init__(
        self,
        data,
        dim1=64,
        dim2=64,
        pre_fc_count=1,
        gc_count=3,
        post_fc_count=1,
        pool="global_mean_pool",
        pool_order="early",
        batch_norm="True",
        batch_track_stats="True",
        act="relu",
        dropout_rate=0.0,
        **kwargs
    ):
        super(CGCNN, self).__init__()

        # The flags are conventionally passed as the strings "True"/"False".
        # Going through str() makes real booleans behave the same way instead
        # of being silently misread by a string comparison.
        self.batch_track_stats = str(batch_track_stats) != "False"
        # The raw value is stored unchanged so existing readers of this
        # attribute keep seeing what was passed in.
        self.batch_norm = batch_norm
        use_batch_norm = str(batch_norm) == "True"
        self.pool = pool
        self.act = act
        self.pool_order = pool_order
        self.dropout_rate = dropout_rate

        assert gc_count > 0, "Need at least 1 GC layer"

        # Without pre-GNN dense layers the convolutions see the raw node
        # features; otherwise they see the dim1-wide dense output. The
        # post-GNN input has the same width as the convolutions.
        if pre_fc_count == 0:
            gc_dim = data.num_features
        else:
            gc_dim = dim1
        post_fc_dim = gc_dim

        # Output width is taken from the first sample's target. 0-D and 1-D
        # targets are treated as a single scalar output; the previous
        # len(y[0]) raised TypeError for 1-D targets.
        target = data[0].y
        output_dim = target.shape[1] if target.ndim > 1 else 1

        # Pre-GNN dense layers (NOTE: in v0.1 this was always 1 layer).
        self.pre_lin_list = torch.nn.ModuleList()
        for i in range(pre_fc_count):
            in_dim = data.num_features if i == 0 else dim1
            self.pre_lin_list.append(torch.nn.Linear(in_dim, dim1))

        # GNN layers, each optionally paired with a batch-norm layer.
        self.conv_list = torch.nn.ModuleList()
        self.bn_list = torch.nn.ModuleList()
        for _ in range(gc_count):
            conv = CGConv(
                gc_dim, data.num_edge_features, aggr="mean", batch_norm=False
            )
            self.conv_list.append(conv)
            # Disabling running stats can prevent some instabilities, but may
            # make val/test performance depend on the loader batch size.
            if use_batch_norm:
                bn = BatchNorm1d(gc_dim, track_running_stats=self.batch_track_stats)
                self.bn_list.append(bn)

        # Post-GNN dense layers (NOTE: v0.1 had a minimum of 2 dense layers;
        # the minimum is now zero). Early set2set pooling doubles the width
        # of what the first post-GNN layer receives.
        early_set2set = self.pool_order == "early" and self.pool == "set2set"
        pooled_dim = post_fc_dim * 2 if early_set2set else post_fc_dim
        self.post_lin_list = torch.nn.ModuleList()
        for i in range(post_fc_count):
            in_dim = pooled_dim if i == 0 else dim2
            self.post_lin_list.append(torch.nn.Linear(in_dim, dim2))
        if post_fc_count > 0:
            self.lin_out = torch.nn.Linear(dim2, output_dim)
        else:
            self.lin_out = torch.nn.Linear(pooled_dim, output_dim)

        # Set2set pooling (if used).
        # TODO: should processing_steps be a hyperparameter?
        if early_set2set:
            self.set2set = Set2Set(post_fc_dim, processing_steps=3)
        elif self.pool_order == "late" and self.pool == "set2set":
            self.set2set = Set2Set(output_dim, processing_steps=3, num_layers=1)
            # Workaround for the dimension doubled by set2set; set2set is not
            # recommended together with late pooling.
            self.lin_out_2 = torch.nn.Linear(output_dim * 2, output_dim)

    def _activate(self, out):
        """Apply the activation named by ``self.act`` from ``torch.nn.functional``."""
        return getattr(F, self.act)(out)

    def _apply_post_layers(self, out):
        """Run the post-GNN dense layers followed by the output layer."""
        for lin in self.post_lin_list:
            out = self._activate(lin(out))
        return self.lin_out(out)

    def forward(self, data):
        """Compute predictions for a batch.

        Args:
            data: Batch object providing ``x``, ``edge_index``, ``edge_attr``
                and ``batch``.

        Returns:
            The output tensor, flattened with ``view(-1)`` when its second
            dimension has size 1.
        """
        # Same rule as the constructor: "True" or True enables batch norm.
        use_batch_norm = str(self.batch_norm) == "True"

        # Pre-GNN dense layers; with none configured the convolutions start
        # from the raw node features.
        out = data.x
        for lin in self.pre_lin_list:
            out = self._activate(lin(out))

        # GNN layers: convolution, optional batch norm, then dropout.
        for i, conv in enumerate(self.conv_list):
            out = conv(out, data.edge_index, data.edge_attr)
            if use_batch_norm:
                out = self.bn_list[i](out)
            out = F.dropout(out, p=self.dropout_rate, training=self.training)

        # Pooling and post-GNN dense layers. Any pool_order other than
        # "early"/"late" skips this stage entirely, as before.
        if self.pool_order == "early":
            if self.pool == "set2set":
                out = self.set2set(out, data.batch)
            else:
                out = getattr(torch_geometric.nn, self.pool)(out, data.batch)
            out = self._apply_post_layers(out)
        elif self.pool_order == "late":
            out = self._apply_post_layers(out)
            if self.pool == "set2set":
                out = self.set2set(out, data.batch)
                out = self.lin_out_2(out)
            else:
                out = getattr(torch_geometric.nn, self.pool)(out, data.batch)

        if out.shape[1] == 1:
            return out.view(-1)
        return out