| """ |
| (c) Adaptation of the code from https://github.com/SitaoLuan/ACM-GNN |
| """ |
|
|
| import torch |
| import torch.nn as nn |
|
|
| from torch import Tensor |
| from typing import Union |
| from torch_geometric.nn.conv import MessagePassing |
| from torch_geometric.nn.inits import reset |
| from torch_geometric.typing import OptPairTensor, OptTensor, Size |
| from torch_geometric.utils import scatter |
|
|
| from .utils import create_activation |
|
|
|
|
class ACM_GIN(MessagePassing):
    r"""GIN-style graph convolution with Adaptive Channel Mixing (ACM).

    Let ``m`` be the degree-normalised neighbour aggregate: the edge-weighted
    sum of incoming neighbour features divided by the summed weight of the
    edges entering each receiving node. Three channels are then built from
    the target-node features ``x``:

    * low-pass:  ``(x + m) / 2`` -> ``nn_lowpass``
    * high-pass: ``(x - m) / 2`` -> ``nn_highpass``
    * full-pass: ``x``           -> ``nn_fullpass``

    Each channel output is scored by its ``nn_*_proj`` network followed by a
    sigmoid. The three scores are concatenated, divided by ``T``, passed
    through ``nn_mix`` and softmax-normalised along dim 1. The layer output is
    the per-node weighted sum of the three channel outputs.

    Args:
        nn_lowpass: Network applied to the low-pass features.
        nn_highpass: Network applied to the high-pass features.
        nn_fullpass: Network applied to the raw target-node features.
        nn_lowpass_proj: Scores the low-pass output. Expected to emit a single
            column per node, since the mixing weights are read from columns
            0, 1 and 2 of the concatenated scores.
        nn_highpass_proj: Scores the high-pass output (same expectation).
        nn_fullpass_proj: Scores the full-pass output (same expectation).
        nn_mix: Maps the concatenated channel scores to three mixing logits.
        T: Divisor applied to the concatenated scores before ``nn_mix``.
        **kwargs: Forwarded to ``MessagePassing``; ``aggr`` defaults to
            ``"add"``.
    """

    def __init__(
        self,
        nn_lowpass: torch.nn.Module,
        nn_highpass: torch.nn.Module,
        nn_fullpass: torch.nn.Module,
        nn_lowpass_proj: torch.nn.Module,
        nn_highpass_proj: torch.nn.Module,
        nn_fullpass_proj: torch.nn.Module,
        nn_mix: torch.nn.Module,
        T: float = 3.0,
        **kwargs,
    ):
        # Sum aggregation; the mean is formed explicitly in ``forward`` by
        # dividing by the weighted in-degree.
        kwargs.setdefault("aggr", "add")
        super().__init__(**kwargs)
        self.nn_lowpass = nn_lowpass
        self.nn_highpass = nn_highpass
        self.nn_fullpass = nn_fullpass
        self.nn_lowpass_proj = nn_lowpass_proj
        self.nn_highpass_proj = nn_highpass_proj
        self.nn_fullpass_proj = nn_fullpass_proj
        self.nn_mix = nn_mix
        self.sigmoid = torch.nn.Sigmoid()
        # Normalises the mixing logits along dim 1 (one row per node).
        self.softmax = torch.nn.Softmax(dim=1)
        self.T = T
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise every wrapped sub-network via PyG's ``reset``."""
        reset(self.nn_lowpass)
        reset(self.nn_highpass)
        reset(self.nn_fullpass)
        reset(self.nn_lowpass_proj)
        reset(self.nn_highpass_proj)
        reset(self.nn_fullpass_proj)
        reset(self.nn_mix)

    def forward(
        self,
        x: Union[Tensor, OptPairTensor],
        edge_index: Tensor,
        edge_weight: OptTensor = None,
        size: Size = None,
    ) -> Tensor:
        """Apply one ACM-GIN layer.

        Args:
            x: Node features, or a ``(source, target)`` feature pair for
                bipartite message passing. Target features must be present.
            edge_index: Graph connectivity as a ``[2, num_edges]`` index tensor.
            edge_weight: One scalar weight per edge. Defaults to all ones
                (a plain neighbour mean) when omitted.
            size: Optional size hint forwarded to ``propagate``.

        Returns:
            Per-node mixture of the low-, high- and full-pass channel outputs.

        Raises:
            ValueError: If the target-node features ``x[1]`` are ``None``.
        """
        if isinstance(x, Tensor):
            x: OptPairTensor = (x, x)

        x_r = x[1]
        if x_r is None:
            # All three channels are built from the target features, so the
            # layer is undefined without them.
            raise ValueError(
                "ACM_GIN requires target-node features (x[1] must not be None)."
            )

        if edge_weight is None:
            # Unweighted graph: every edge counts once, so the normalised
            # aggregate below reduces to a plain neighbour mean.
            edge_weight = torch.ones(
                edge_index.size(1), dtype=x[0].dtype, device=edge_index.device
            )

        # Edge-weighted sum of neighbour features at each receiving node.
        out = self.propagate(edge_index, x=x, edge_weight=edge_weight, size=size)

        # Divide by the weighted in-degree of the receiving node. The
        # receiving side follows ``self.flow``, matching where ``propagate``
        # aggregated. Isolated nodes get 0 instead of inf.
        index = edge_index[1] if self.flow == "source_to_target" else edge_index[0]
        deg = scatter(edge_weight, index, 0, out.size(0), reduce="sum")
        deg_inv = 1.0 / deg
        deg_inv.masked_fill_(deg_inv == float("inf"), 0)
        out = deg_inv.view(-1, 1) * out

        # Low-pass keeps what a node shares with its neighbourhood; high-pass
        # keeps where it differs from it.
        out_lowpass = self.nn_lowpass((x_r + out) / 2.0)
        out_highpass = self.nn_highpass((x_r - out) / 2.0)
        out_fullpass = self.nn_fullpass(x_r)

        alpha_lowpass = self.sigmoid(self.nn_lowpass_proj(out_lowpass))
        alpha_highpass = self.sigmoid(self.nn_highpass_proj(out_highpass))
        alpha_fullpass = self.sigmoid(self.nn_fullpass_proj(out_fullpass))
        alpha_cat = torch.concat([alpha_lowpass, alpha_highpass, alpha_fullpass], dim=1)
        alpha_cat = self.softmax(self.nn_mix(alpha_cat / self.T))

        out = alpha_cat[:, 0].view(-1, 1) * out_lowpass
        out = out + alpha_cat[:, 1].view(-1, 1) * out_highpass
        out = out + alpha_cat[:, 2].view(-1, 1) * out_fullpass
        return out

    def message(self, x_j: Tensor, edge_weight: Tensor) -> Tensor:
        """Scale each incoming neighbour feature row by its edge weight."""
        return edge_weight.view(-1, 1) * x_j

    def __repr__(self) -> str:
        # The original referenced a non-existent ``self.nn`` attribute, which
        # made repr()/print() raise AttributeError.
        return (
            f"{self.__class__.__name__}("
            f"nn_lowpass={self.nn_lowpass}, "
            f"nn_highpass={self.nn_highpass}, "
            f"nn_fullpass={self.nn_fullpass}, "
            f"T={self.T})"
        )
|
|
|
|
class ACM_GIN_model(nn.Module):
    """A stack of ``ACM_GIN`` layers.

    The first layer reads ``in_dim`` features and the last layer emits
    ``out_dim`` features; every other width is ``hidden_dim``. Each layer
    owns the following modules, which are also collected in the ``nns_*``
    module lists:

    * three channel MLPs (low-, high- and full-pass), each
      Linear -> [BatchNorm1d] -> activation -> Linear -> [BatchNorm1d] -> activation;
    * three single-output score projections;
    * a 3x3 mixing layer.

    Args:
        in_dim: Input feature width.
        out_dim: Output feature width of the last layer.
        num_layers: Number of stacked ``ACM_GIN`` layers.
        hidden_dim: Hidden feature width, also the MLPs' inner width.
        batchnorm: If truthy, a BatchNorm1d follows each Linear in the
            channel MLPs.
        activation: Name passed to ``create_activation``. The single returned
            module instance is reused in every channel MLP.
            NOTE(review): if this is a parametric activation (e.g. PReLU), its
            parameters are shared across all layers; confirm this is intended.
    """

    def __init__(
        self, in_dim, out_dim, num_layers, hidden_dim, batchnorm, activation="relu"
    ):
        super().__init__()
        self.num_layers = num_layers
        self.hidden_dim = hidden_dim
        self.gnn_batchnorm = batchnorm
        self.out_dim = out_dim

        self.ACM_convs = nn.ModuleList()
        self.nns_lowpass = nn.ModuleList()
        self.nns_highpass = nn.ModuleList()
        self.nns_fullpass = nn.ModuleList()
        self.nns_lowpass_proj = nn.ModuleList()
        self.nns_highpass_proj = nn.ModuleList()
        self.nns_fullpass_proj = nn.ModuleList()
        self.nns_mix = nn.ModuleList()

        self.activation = create_activation(activation)

        final_layer = num_layers - 1
        for layer in range(num_layers):
            fan_in = in_dim if layer == 0 else hidden_dim
            fan_out = out_dim if layer == final_layer else hidden_dim

            # Modules are created in a fixed order (score heads, mixer, channel
            # MLPs, then the conv) so that seeded initialisation is reproducible.
            proj_low, proj_high, proj_full = (nn.Linear(fan_out, 1) for _ in range(3))
            mixer = nn.Linear(3, 3)
            mlp_low, mlp_high, mlp_full = (
                self._build_channel_mlp(fan_in, fan_out) for _ in range(3)
            )

            self.nns_lowpass_proj.append(proj_low)
            self.nns_highpass_proj.append(proj_high)
            self.nns_fullpass_proj.append(proj_full)
            self.nns_mix.append(mixer)
            self.nns_lowpass.append(mlp_low)
            self.nns_highpass.append(mlp_high)
            self.nns_fullpass.append(mlp_full)

            self.ACM_convs.append(
                ACM_GIN(
                    nn_lowpass=mlp_low,
                    nn_highpass=mlp_high,
                    nn_fullpass=mlp_full,
                    nn_lowpass_proj=proj_low,
                    nn_highpass_proj=proj_high,
                    nn_fullpass_proj=proj_full,
                    nn_mix=mixer,
                )
            )

    def _build_channel_mlp(self, fan_in, fan_out):
        """Build one two-layer channel MLP, with BatchNorm if enabled."""
        use_bn = bool(self.gnn_batchnorm)
        layers = [nn.Linear(fan_in, self.hidden_dim)]
        if use_bn:
            layers.append(nn.BatchNorm1d(self.hidden_dim))
        layers.append(self.activation)
        layers.append(nn.Linear(self.hidden_dim, fan_out))
        if use_bn:
            layers.append(nn.BatchNorm1d(fan_out))
        layers.append(self.activation)
        return nn.Sequential(*layers)

    def reset_parameters(self):
        """Re-initialise every Linear and BatchNorm1d submodule in place."""
        for module in self.modules():
            if isinstance(module, (nn.Linear, nn.BatchNorm1d)):
                module.reset_parameters()

    def forward(self, x, edge_index, edge_attr, return_hidden=False):
        """Run all layers in sequence, passing ``edge_attr`` as edge weights.

        Returns the final node features. When ``return_hidden`` is true, also
        returns a list containing the output of every layer, the last one
        included.
        """
        hidden = []
        for conv in self.ACM_convs:
            x = conv(x=x, edge_index=edge_index, edge_weight=edge_attr)
            hidden.append(x)
        return (x, hidden) if return_hidden else x
|
|
|
|
if __name__ == "__main__":
    # Smoke check: build a small model and report its trainable parameter count.
    model = ACM_GIN_model(
        in_dim=46, out_dim=46, num_layers=2, hidden_dim=256, batchnorm=True
    )
    trainable = sum(param.numel() for param in model.parameters() if param.requires_grad)
    print(trainable)
    print("")
|
|