# SEUyishu's picture
# Upload 46 files
# dfc4f2b verified
import torch
from torch import Tensor
import torch.nn.functional as F
from torch.nn import Sequential, Linear, ReLU, BatchNorm1d
import torch_geometric
from torch_geometric.nn import (
Set2Set,
global_mean_pool,
global_add_pool,
global_max_pool,
MetaLayer,
)
from torch_scatter import scatter_mean, scatter_add, scatter_max, scatter
# Megnet
class Megnet_EdgeModel(torch.nn.Module):
    """Edge update block of a MEGNet MetaLayer.

    Builds each new edge feature from the concatenation
    [src, dest, edge_attr, u[batch]] passed through an MLP of
    ``fc_layers + 1`` Linear layers, each followed by the configured
    activation, optional BatchNorm1d, and dropout.
    """

    def __init__(self, dim, act, batch_norm, batch_track_stats, dropout_rate, fc_layers=2):
        super(Megnet_EdgeModel, self).__init__()
        self.act = act
        self.fc_layers = fc_layers
        # batch_track_stats arrives as the string "True"/"False"
        self.batch_track_stats = batch_track_stats != "False"
        self.batch_norm = batch_norm
        self.dropout_rate = dropout_rate
        self.edge_mlp = torch.nn.ModuleList()
        self.bn_list = torch.nn.ModuleList()
        for layer_idx in range(self.fc_layers + 1):
            # first layer consumes the 4-way concatenated input
            in_dim = dim * 4 if layer_idx == 0 else dim
            self.edge_mlp.append(torch.nn.Linear(in_dim, dim))
            if self.batch_norm == "True":
                self.bn_list.append(
                    BatchNorm1d(dim, track_running_stats=self.batch_track_stats)
                )

    def forward(self, src, dest, edge_attr, u, batch):
        """Return updated edge features of shape (num_edges, dim)."""
        out = torch.cat([src, dest, edge_attr, u[batch]], dim=1)
        for layer_idx, lin in enumerate(self.edge_mlp):
            out = getattr(F, self.act)(lin(out))
            if self.batch_norm == "True":
                out = self.bn_list[layer_idx](out)
            out = F.dropout(out, p=self.dropout_rate, training=self.training)
        return out
class Megnet_NodeModel(torch.nn.Module):
    """Node update block of a MEGNet MetaLayer.

    Builds each new node feature from the concatenation
    [x, mean of edge features aggregated to the source node, u[batch]]
    passed through an MLP of ``fc_layers + 1`` Linear layers, each followed
    by the configured activation, optional BatchNorm1d, and dropout.
    """

    def __init__(self, dim, act, batch_norm, batch_track_stats, dropout_rate, fc_layers=2):
        super(Megnet_NodeModel, self).__init__()
        self.act = act
        self.fc_layers = fc_layers
        # batch_track_stats arrives as the string "True"/"False"
        self.batch_track_stats = batch_track_stats != "False"
        self.batch_norm = batch_norm
        self.dropout_rate = dropout_rate
        self.node_mlp = torch.nn.ModuleList()
        self.bn_list = torch.nn.ModuleList()
        for layer_idx in range(self.fc_layers + 1):
            # first layer consumes the 3-way concatenated input
            in_dim = dim * 3 if layer_idx == 0 else dim
            self.node_mlp.append(torch.nn.Linear(in_dim, dim))
            if self.batch_norm == "True":
                self.bn_list.append(
                    BatchNorm1d(dim, track_running_stats=self.batch_track_stats)
                )

    def forward(self, x, edge_index, edge_attr, u, batch):
        """Return updated node features of shape (num_nodes, dim)."""
        # BUG FIX: pass dim_size so the aggregate has exactly one row per node.
        # Without it scatter_mean returns max(edge_index[0]) + 1 rows, and the
        # cat below fails whenever trailing nodes have no outgoing edges.
        v_e = scatter_mean(edge_attr, edge_index[0, :], dim=0, dim_size=x.size(0))
        out = torch.cat([x, v_e, u[batch]], dim=1)
        for layer_idx, lin in enumerate(self.node_mlp):
            out = getattr(F, self.act)(lin(out))
            if self.batch_norm == "True":
                out = self.bn_list[layer_idx](out)
            out = F.dropout(out, p=self.dropout_rate, training=self.training)
        return out
class Megnet_GlobalModel(torch.nn.Module):
    """Global-state update block of a MEGNet MetaLayer.

    Builds each new per-graph state from the concatenation
    [graph-mean of edge features, graph-mean of node features, u]
    passed through an MLP of ``fc_layers + 1`` Linear layers, each followed
    by the configured activation, optional BatchNorm1d, and dropout.
    """

    def __init__(self, dim, act, batch_norm, batch_track_stats, dropout_rate, fc_layers=2):
        super(Megnet_GlobalModel, self).__init__()
        self.act = act
        self.fc_layers = fc_layers
        # batch_track_stats arrives as the string "True"/"False"
        self.batch_track_stats = batch_track_stats != "False"
        self.batch_norm = batch_norm
        self.dropout_rate = dropout_rate
        self.global_mlp = torch.nn.ModuleList()
        self.bn_list = torch.nn.ModuleList()
        for layer_idx in range(self.fc_layers + 1):
            # first layer consumes the 3-way concatenated input
            in_dim = dim * 3 if layer_idx == 0 else dim
            self.global_mlp.append(torch.nn.Linear(in_dim, dim))
            if self.batch_norm == "True":
                self.bn_list.append(
                    BatchNorm1d(dim, track_running_stats=self.batch_track_stats)
                )

    def forward(self, x, edge_index, edge_attr, u, batch):
        """Return updated global features of shape (num_graphs, dim)."""
        # Edge features -> per-node means -> per-graph means.
        # BUG FIX: dim_size forces one row per node (first call) and one row
        # per graph (later calls). Without it the intermediate can have fewer
        # rows than len(batch), making the second scatter index out of range.
        u_e = scatter_mean(edge_attr, edge_index[0, :], dim=0, dim_size=x.size(0))
        u_e = scatter_mean(u_e, batch, dim=0, dim_size=u.size(0))
        u_v = scatter_mean(x, batch, dim=0, dim_size=u.size(0))
        out = torch.cat([u_e, u_v, u], dim=1)
        for layer_idx, lin in enumerate(self.global_mlp):
            out = getattr(F, self.act)(lin(out))
            if self.batch_norm == "True":
                out = self.bn_list[layer_idx](out)
            out = F.dropout(out, p=self.dropout_rate, training=self.training)
        return out
class MEGNet(torch.nn.Module):
    """MEGNet-style graph network.

    Pipeline: optional pre-GNN dense layers -> a stack of residual MetaLayer
    blocks (edge/node/global updates) -> pooling ("early" before or "late"
    after the post-FC layers) -> dense output head.

    Args:
        data: dataset/batch object; ``num_features``, ``num_edge_features``,
            ``data[0].y`` and ``data[0].u`` are read to size the layers.
        dim1: width of the pre-GNN dense layers.
        dim2: width of the post-GNN dense layers.
        dim3: feature width inside the MetaLayer blocks.
        pre_fc_count: number of dense layers before the GNN stack (0 disables).
        gc_count: number of MetaLayer blocks (must be >= 1).
        gc_fc_count: dense layers inside each edge/node/global sub-model.
        post_fc_count: number of dense layers after pooling (0 disables).
        pool: "global_mean_pool", "global_max_pool", "global_sum_pool",
            or "set2set".
        pool_order: "early" or "late".
        batch_norm, batch_track_stats: string flags "True"/"False".
        act: name of an activation function in ``torch.nn.functional``.
        dropout_rate: dropout probability used inside the MetaLayer models.
    """
    def __init__(
        self,
        data,
        dim1=64,
        dim2=64,
        dim3=64,
        pre_fc_count=1,
        gc_count=3,
        gc_fc_count=2,
        post_fc_count=1,
        pool="global_mean_pool",
        pool_order="early",
        batch_norm="True",
        batch_track_stats="True",
        act="relu",
        dropout_rate=0.0,
        **kwargs
    ):
        super(MEGNet, self).__init__()
        # String flag -> bool (config files pass "True"/"False" strings).
        if batch_track_stats == "False":
            self.batch_track_stats = False
        else:
            self.batch_track_stats = True
        self.batch_norm = batch_norm
        self.pool = pool
        # Map the pool name to a scatter reduce op for early pooling.
        # NOTE(review): pool_reduce is left unset for pool="set2set"; that is
        # fine for set2set itself, but any other unrecognized pool name would
        # fail later at self.pool_reduce — confirm allowed pool values.
        if pool == "global_mean_pool":
            self.pool_reduce="mean"
        elif pool== "global_max_pool":
            self.pool_reduce="max"
        elif pool== "global_sum_pool":
            self.pool_reduce="sum"
        self.act = act
        self.pool_order = pool_order
        self.dropout_rate = dropout_rate
        ##Determine gc dimension dimension
        assert gc_count > 0, "Need at least 1 GC layer"
        # With no pre-FC layers the GNN consumes raw node features directly.
        if pre_fc_count == 0:
            gc_dim = data.num_features
        else:
            gc_dim = dim1
        ##Determine post_fc dimension
        post_fc_dim = dim3
        ##Determine output dimension length
        # Scalar target -> 1 output; otherwise one output per target component.
        if data[0].y.ndim == 0:
            output_dim = 1
        else:
            output_dim = len(data[0].y[0])
        ##Set up pre-GNN dense layers (NOTE: in v0.1 this is always set to 1 layer)
        if pre_fc_count > 0:
            self.pre_lin_list = torch.nn.ModuleList()
            for i in range(pre_fc_count):
                if i == 0:
                    # First layer maps raw node features up to dim1.
                    lin = torch.nn.Linear(data.num_features, dim1)
                    self.pre_lin_list.append(lin)
                else:
                    lin = torch.nn.Linear(dim1, dim1)
                    self.pre_lin_list.append(lin)
        elif pre_fc_count == 0:
            # Empty list so forward() can iterate it unconditionally.
            self.pre_lin_list = torch.nn.ModuleList()
        ##Set up GNN layers
        # Per GC block: embedding MLPs for edge/node/global inputs plus a
        # MetaLayer holding the three MEGNet update models.
        self.e_embed_list = torch.nn.ModuleList()
        self.x_embed_list = torch.nn.ModuleList()
        self.u_embed_list = torch.nn.ModuleList()
        self.conv_list = torch.nn.ModuleList()
        self.bn_list = torch.nn.ModuleList()
        for i in range(gc_count):
            if i == 0:
                # Block 0 embeds the raw edge / (pre-FC) node / global inputs
                # into the shared dim3 width.
                e_embed = Sequential(
                    Linear(data.num_edge_features, dim3), ReLU(), Linear(dim3, dim3), ReLU()
                )
                x_embed = Sequential(
                    Linear(gc_dim, dim3), ReLU(), Linear(dim3, dim3), ReLU()
                )
                u_embed = Sequential(
                    Linear((data[0].u.shape[1]), dim3), ReLU(), Linear(dim3, dim3), ReLU()
                )
                self.e_embed_list.append(e_embed)
                self.x_embed_list.append(x_embed)
                self.u_embed_list.append(u_embed)
                self.conv_list.append(
                    MetaLayer(
                        Megnet_EdgeModel(dim3, self.act, self.batch_norm, self.batch_track_stats, self.dropout_rate, gc_fc_count),
                        Megnet_NodeModel(dim3, self.act, self.batch_norm, self.batch_track_stats, self.dropout_rate, gc_fc_count),
                        Megnet_GlobalModel(dim3, self.act, self.batch_norm, self.batch_track_stats, self.dropout_rate, gc_fc_count),
                    )
                )
            elif i > 0:
                # Later blocks are dim3 -> dim3 throughout.
                e_embed = Sequential(Linear(dim3, dim3), ReLU(), Linear(dim3, dim3), ReLU())
                x_embed = Sequential(Linear(dim3, dim3), ReLU(), Linear(dim3, dim3), ReLU())
                u_embed = Sequential(Linear(dim3, dim3), ReLU(), Linear(dim3, dim3), ReLU())
                self.e_embed_list.append(e_embed)
                self.x_embed_list.append(x_embed)
                self.u_embed_list.append(u_embed)
                self.conv_list.append(
                    MetaLayer(
                        Megnet_EdgeModel(dim3, self.act, self.batch_norm, self.batch_track_stats, self.dropout_rate, gc_fc_count),
                        Megnet_NodeModel(dim3, self.act, self.batch_norm, self.batch_track_stats, self.dropout_rate, gc_fc_count),
                        Megnet_GlobalModel(dim3, self.act, self.batch_norm, self.batch_track_stats, self.dropout_rate, gc_fc_count),
                    )
                )
        ##Set up post-GNN dense layers (NOTE: in v0.1 there was a minimum of 2 dense layers, and fc_count(now post_fc_count) added to this number. In the current version, the minimum is zero)
        if post_fc_count > 0:
            self.post_lin_list = torch.nn.ModuleList()
            for i in range(post_fc_count):
                if i == 0:
                    ##Set2set pooling has doubled dimension
                    # Early pooling concatenates [x_pool, e_pool, u]:
                    # set2set doubles x and e (2+2+1 = 5*dim), otherwise 3*dim.
                    if self.pool_order == "early" and self.pool == "set2set":
                        lin = torch.nn.Linear(post_fc_dim * 5, dim2)
                    elif self.pool_order == "early" and self.pool != "set2set":
                        lin = torch.nn.Linear(post_fc_dim * 3, dim2)
                    elif self.pool_order == "late":
                        lin = torch.nn.Linear(post_fc_dim, dim2)
                    self.post_lin_list.append(lin)
                else:
                    lin = torch.nn.Linear(dim2, dim2)
                    self.post_lin_list.append(lin)
            self.lin_out = torch.nn.Linear(dim2, output_dim)
        elif post_fc_count == 0:
            # No hidden post-FC layers: lin_out maps the pooled features
            # straight to the output dimension.
            self.post_lin_list = torch.nn.ModuleList()
            if self.pool_order == "early" and self.pool == "set2set":
                self.lin_out = torch.nn.Linear(post_fc_dim * 5, output_dim)
            elif self.pool_order == "early" and self.pool != "set2set":
                self.lin_out = torch.nn.Linear(post_fc_dim * 3, output_dim)
            else:
                self.lin_out = torch.nn.Linear(post_fc_dim, output_dim)
        ##Set up set2set pooling (if used)
        if self.pool_order == "early" and self.pool == "set2set":
            self.set2set_x = Set2Set(post_fc_dim, processing_steps=3)
            self.set2set_e = Set2Set(post_fc_dim, processing_steps=3)
        elif self.pool_order == "late" and self.pool == "set2set":
            self.set2set_x = Set2Set(output_dim, processing_steps=3, num_layers=1)
            # workaround for doubled dimension by set2set; if late pooling not reccomended to use set2set
            self.lin_out_2 = torch.nn.Linear(output_dim * 2, output_dim)
    def forward(self, data):
        """Run the full model on a PyG batch; returns per-graph predictions
        (flattened to 1-D when the output dimension is 1)."""
        ##Pre-GNN dense layers
        for i in range(0, len(self.pre_lin_list)):
            if i == 0:
                out = self.pre_lin_list[i](data.x)
                out = getattr(F, self.act)(out)
            else:
                out = self.pre_lin_list[i](out)
                out = getattr(F, self.act)(out)
        ##GNN layers
        # Each block embeds its inputs, applies the MetaLayer, and adds a
        # residual connection (to the embedding at block 0, to the running
        # x/e/u afterwards).
        for i in range(0, len(self.conv_list)):
            if i == 0:
                if len(self.pre_lin_list) == 0:
                    # No pre-FC layers: embed the raw node features.
                    e_temp = self.e_embed_list[i](data.edge_attr)
                    x_temp = self.x_embed_list[i](data.x)
                    u_temp = self.u_embed_list[i](data.u)
                    x_out, e_out, u_out = self.conv_list[i](
                        x_temp, data.edge_index, e_temp, u_temp, data.batch
                    )
                    x = torch.add(x_out, x_temp)
                    e = torch.add(e_out, e_temp)
                    u = torch.add(u_out, u_temp)
                else:
                    e_temp = self.e_embed_list[i](data.edge_attr)
                    x_temp = self.x_embed_list[i](out)
                    u_temp = self.u_embed_list[i](data.u)
                    x_out, e_out, u_out = self.conv_list[i](
                        x_temp, data.edge_index, e_temp, u_temp, data.batch
                    )
                    x = torch.add(x_out, x_temp)
                    e = torch.add(e_out, e_temp)
                    u = torch.add(u_out, u_temp)
            elif i > 0:
                e_temp = self.e_embed_list[i](e)
                x_temp = self.x_embed_list[i](x)
                u_temp = self.u_embed_list[i](u)
                x_out, e_out, u_out = self.conv_list[i](
                    x_temp, data.edge_index, e_temp, u_temp, data.batch
                )
                x = torch.add(x_out, x)
                e = torch.add(e_out, e)
                u = torch.add(u_out, u)
        ##Post-GNN dense layers
        if self.pool_order == "early":
            if self.pool == "set2set":
                x_pool = self.set2set_x(x, data.batch)
                # Edges are first averaged onto their source nodes so they can
                # be pooled with the node-level batch vector.
                # NOTE(review): this scatter has no dim_size; if trailing nodes
                # have no outgoing edges, e has fewer rows than data.batch and
                # the set2set call below can fail — confirm against datasets.
                e = scatter(e, data.edge_index[0, :], dim=0, reduce="mean")
                e_pool = self.set2set_e(e, data.batch)
                out = torch.cat([x_pool, e_pool, u], dim=1)
            else:
                x_pool = scatter(x, data.batch, dim=0, reduce=self.pool_reduce)
                # NOTE(review): same missing dim_size concern as above for the
                # edge->node scatter.
                e_pool = scatter(e, data.edge_index[0, :], dim=0, reduce=self.pool_reduce)
                e_pool = scatter(e_pool, data.batch, dim=0, reduce=self.pool_reduce)
                out = torch.cat([x_pool, e_pool, u], dim=1)
            for i in range(0, len(self.post_lin_list)):
                out = self.post_lin_list[i](out)
                out = getattr(F, self.act)(out)
            out = self.lin_out(out)
        ##currently only uses node features for late pooling
        elif self.pool_order == "late":
            out = x
            for i in range(0, len(self.post_lin_list)):
                out = self.post_lin_list[i](out)
                out = getattr(F, self.act)(out)
            out = self.lin_out(out)
            if self.pool == "set2set":
                out = self.set2set_x(out, data.batch)
                out = self.lin_out_2(out)
            else:
                # NOTE(review): torch_geometric.nn exposes global_add_pool,
                # not global_sum_pool — pool="global_sum_pool" with late
                # pooling would raise AttributeError here; confirm pool names.
                out = getattr(torch_geometric.nn, self.pool)(out, data.batch)
        # Single-target models return a 1-D tensor for loss convenience.
        if out.shape[1] == 1:
            return out.view(-1)
        else:
            return out