| from typing import Optional
|
|
|
| import torch
|
| import torch.nn as nn
|
| import torch.nn.functional as F
|
|
|
| from .base import LocallyConnected
|
|
|
|
|
class ResNetBlock(nn.Module):
    """Two-layer MLP residual block: y = relu(f(x) + shortcut(x)).

    ``f`` is Linear -> ReLU -> Linear and ``shortcut`` is the identity
    (an empty Sequential), so ``in_size`` must equal ``out_size`` for the
    residual sum to broadcast.
    """

    def __init__(self, in_size=16, h_size=16, out_size=16):
        super().__init__()

        # Residual branch: Linear -> ReLU -> Linear.
        self.f = nn.Sequential(
            nn.Linear(in_features=in_size, out_features=h_size),
            nn.ReLU(),
            nn.Linear(in_features=h_size, out_features=out_size),
        )
        # Identity skip connection.
        self.shortcut = nn.Sequential()

    def forward(self, x):
        """Return relu(residual + skip) for input x."""
        skip = self.shortcut(x)
        residual = self.f(x)
        return F.relu(residual + skip)
|
|
|
|
|
class HyperResNet(nn.Module):
    """Stack of ``num_block`` identical ResNetBlock modules applied in order.

    Each block is built with the same (in_size, h_size, out_size), so for
    chaining to work the blocks' input and output sizes should agree.
    """

    def __init__(self, in_size=16, h_size=16, out_size=16, num_block=2):
        super().__init__()

        # Build num_block identical residual blocks and chain them.
        self.model = nn.Sequential(
            *(ResNetBlock(in_size, h_size, out_size) for _ in range(num_block))
        )

    def forward(self, x):
        """Run x through all residual blocks in sequence."""
        return self.model(x)
|
|
|
|
|
class HyperLocallyConnected(nn.Module):
    """Hyper local linear layer, i.e. Conv1dLocal() with filter size 1 whose
    parameters are produced by another (hyper-)network:

        y = LocallyConnected_{params}(x),
        where params = h(G)

    Args:
        num_linear: num of local linear layers, i.e. d
        input_features: m1
        output_features: m2
        hyper: which hyper-network produces the parameters; one of VALID_HYPER
        n_ens: number of ensemble members (leading dim of G)
        bias: whether to include bias or not
        hyper_hidden_dims: hidden layer sizes forwarded to the hyper-network

    Shape:
        - Input: [n, d, m1]
        - Output: [n, d, m2]

    Attributes:
        weight: [d, m1, m2]
        bias: [d, m2]

    Raises:
        ValueError: if ``hyper`` is not in VALID_HYPER.
        NotImplementedError: if ``hyper`` is listed as valid but has no
            implementation (currently "gnn").
    """

    VALID_HYPER = [
        "mlp",
        "gnn",
        "invariant",
        "per_graph",
        "deep_set",
    ]

    def __init__(
        self,
        num_linear,
        input_features,
        output_features,
        hyper,
        n_ens=1,
        bias=True,
        hyper_hidden_dims: Optional[list] = None,
    ):
        super().__init__()
        self.num_linear = num_linear
        self.input_features = input_features
        self.output_features = output_features
        self.n_ens = n_ens
        self.hyper = hyper

        # Raise instead of assert so the check survives `python -O`.
        if hyper not in self.VALID_HYPER:
            raise ValueError(
                f"hyper hparam not a valid option - choices: {self.VALID_HYPER}"
            )

        # Dispatch table, built here (not at class scope) so the referenced
        # classes — defined later in this module — are already available.
        hyper_types = {
            "invariant": HyperInvariant,
            "mlp": HyperMLP,
            "per_graph": HyperInvariantPerGraph,
            "deep_set": DeepSet,
        }
        if hyper not in hyper_types:
            # "gnn" is accepted by VALID_HYPER but has no implementation;
            # previously this path crashed with an unbound-local NameError
            # on `hyper_type`. Fail with an explicit message instead.
            raise NotImplementedError(
                f"hyper={hyper!r} is a valid option but has no implementation"
            )

        self.hyper_layer = hyper_types[hyper](
            n_ens=n_ens,
            num_linear=num_linear,
            input_features=input_features,
            output_features=output_features,
            bias=bias,
            hidden_dims=hyper_hidden_dims,
        )

    def forward(self, x: torch.Tensor, G: torch.Tensor):
        """Apply the graph-conditioned local linear layer.

        Args:
            x: input activations — presumably [n, d, m1] per the class
                docstring; TODO confirm against callers.
            G: adjacency matrix handed to the hyper-network; moved onto x's
                device/dtype via ``.to(x)``.
        """
        # weights: [n_ens, d, m1, m2] (see HyperMLP.forward); unsqueeze(1)
        # inserts a broadcast dim so x can batch against the ensemble axis.
        # NOTE(review): exact batch layout depends on callers — confirm.
        weights, biases = self.hyper_layer(G.to(x))
        x = torch.matmul(x, weights.unsqueeze(1))
        if biases is not None:
            x += biases.unsqueeze(-2).unsqueeze(1)
        return x
|
|
|
|
|
class AnalyiticLinearLocallyConnected(nn.Module):
    """Analytic linear local linear layer, i.e. Conv1dLocal() with filter
    size 1 whose parameters are computed in closed form by another network:

        y = LocallyConnected_{params}(x),
        where params = h(G)

    Args:
        num_linear: num of local linear layers, i.e. d
        input_features: m1
        hyper: hyper-network identifier (stored; not used for dispatch here)
        n_ens: number of ensemble members
        bias: accepted for signature parity but unused by this layer
        hyper_hidden_dims: accepted for signature parity but unused

    Shape:
        - Input: [n, d, m1]
        - Output: [n, d, m2]

    Attributes:
        weight: [d, m1, m2]
        bias: [d, m2]
    """

    def __init__(
        self,
        num_linear,
        input_features,
        hyper,
        n_ens=1,
        bias=True,
        hyper_hidden_dims: Optional[list] = None,
    ):
        super().__init__()
        self.num_linear = num_linear
        self.input_features = input_features
        self.n_ens = n_ens
        self.hyper = hyper

        # NOTE(review): `bias` and `hyper_hidden_dims` are accepted but never
        # used — presumably kept so this layer's signature matches the other
        # locally-connected layers in this module.
        self.hyper_layer = HyperAnalyticLinear(
            n_ens=n_ens,
            num_linear=num_linear,
            input_features=input_features,
        )

        # Placeholder weights, overwritten on every forward(). This is a
        # plain tensor attribute (not a Parameter/buffer), so it is NOT
        # tracked by state_dict(), .to(), or the optimizer.
        self.weights = torch.randn((n_ens, num_linear, num_linear))

    def forward(self, x: torch.Tensor, dx: torch.Tensor, G: torch.Tensor):
        # Compute closed-form per-node weights from the data and graph;
        # G is moved onto x's device/dtype before use.
        self.weights = self.hyper_layer(x, dx, G.to(x))
        # Apply W^T to x. The squeezes drop assumed singleton dims —
        # presumably x arrives as [1, n, 1, d]-like; TODO confirm the exact
        # caller-side layout.
        x = torch.matmul(
            self.weights.unsqueeze(1).transpose(-2, -1),
            x.squeeze(-2).squeeze(0),
        )
        return x
|
|
|
|
|
class NodeHyperLocallyConnected(nn.Module):
    """Hyper local linear layer, i.e. Conv1dLocal() with filter size 1 whose
    parameters are produced by another (hyper-)network:

        y = LocallyConnected_{params}(x),
        where params = h(G)

    Unlike HyperLocallyConnected, the graph is not a forward() argument: it
    must be assigned to the ``G`` attribute before calling forward().

    Args:
        num_linear: num of local linear layers, i.e. d
        input_features: m1
        output_features: m2
        hyper: which hyper-network produces the parameters; one of VALID_HYPER
        n_ens: number of ensemble members (leading dim of G)
        bias: whether to include bias or not
        hyper_hidden_dims: hidden layer sizes forwarded to the hyper-network

    Shape:
        - Input: [n, d, m1]
        - Output: [n, d, m2]

    Attributes:
        weight: [d, m1, m2]
        bias: [d, m2]
        G: adjacency matrix; starts as None and must be set externally.

    Raises:
        ValueError: if ``hyper`` is not in VALID_HYPER.
        NotImplementedError: if ``hyper`` is listed as valid but has no
            implementation (currently "gnn").
    """

    VALID_HYPER = [
        "mlp",
        "gnn",
        "invariant",
        "per_graph",
        "deep_set",
    ]

    def __init__(
        self,
        num_linear,
        input_features,
        output_features,
        hyper,
        n_ens=1,
        bias=True,
        hyper_hidden_dims: Optional[list] = None,
    ):
        super().__init__()
        self.num_linear = num_linear
        self.input_features = input_features
        self.output_features = output_features
        self.n_ens = n_ens
        self.hyper = hyper
        # Graph placeholder; callers assign `.G` before the first forward().
        self.G = None

        # Raise instead of assert so the check survives `python -O`.
        if hyper not in self.VALID_HYPER:
            raise ValueError(
                f"hyper hparam not a valid option - choices: {self.VALID_HYPER}"
            )

        # Dispatch table, built here (not at class scope) so the referenced
        # classes — defined later in this module — are already available.
        hyper_types = {
            "invariant": HyperInvariant,
            "mlp": HyperMLP,
            "per_graph": HyperInvariantPerGraph,
            "deep_set": DeepSet,
        }
        if hyper not in hyper_types:
            # "gnn" is accepted by VALID_HYPER but has no implementation;
            # previously this path crashed with an unbound-local NameError
            # on `hyper_type`. Fail with an explicit message instead.
            raise NotImplementedError(
                f"hyper={hyper!r} is a valid option but has no implementation"
            )

        self.hyper_layer = hyper_types[hyper](
            n_ens=n_ens,
            num_linear=num_linear,
            input_features=input_features,
            output_features=output_features,
            bias=bias,
            hidden_dims=hyper_hidden_dims,
        )

    def forward(self, x: torch.Tensor):
        """Apply the graph-conditioned local linear layer to x.

        Requires ``self.G`` to have been assigned; previously an unset G
        surfaced as an opaque ``AttributeError: 'NoneType' object has no
        attribute 'to'`` — fail fast with an actionable message instead.
        """
        if self.G is None:
            raise RuntimeError(
                "NodeHyperLocallyConnected.G is unset; assign the adjacency "
                "matrix to `.G` before calling forward()"
            )
        # weights: [n_ens, d, m1, m2] (see HyperMLP.forward); unsqueeze(1)
        # inserts a broadcast dim so x can batch against the ensemble axis.
        # NOTE(review): exact batch layout depends on callers — confirm.
        weights, biases = self.hyper_layer(self.G.to(x))
        x = torch.matmul(x, weights.unsqueeze(1))
        if biases is not None:
            x += biases.unsqueeze(-2).unsqueeze(1)
        return x
|
|
|
|
|
class MLP(nn.Module):
    """Feed-forward network mapping dims[0] -> ... -> dims[-1].

    Consecutive Linear layers are separated by ELU activations; there is no
    activation after the final layer.
    """

    def __init__(self, dims, bias=True):
        super().__init__()
        layers = []
        for in_dim, out_dim in zip(dims[:-1], dims[1:]):
            # Insert an ELU before every layer except the first.
            if layers:
                layers.append(nn.ELU())
            layers.append(nn.Linear(in_dim, out_dim, bias=bias))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Run x through the stacked layers."""
        return self.net(x)
|
|
|
|
|
class DeepSet(nn.Module):
    """Deep-set style network over per-node features.

    Each node's input features are concatenated with a learned per-node
    embedding, passed through ``phi_net``, summed over the node axis
    (permutation-invariant pooling), and mapped to ``output_features`` by
    ``f_net``.

    Args:
        num_nodes: number of nodes (size of the set).
        input_features: feature size per node.
        output_features: size of the final output.
        bias: whether the internal MLPs use bias terms.
        embedding_size: per-node embedding width (default 16).
        phi_dims: hidden dims of the per-element network (default [64, 64]).
        f_dims: hidden dims of the post-pooling network (default [64, 64]).
    """

    def __init__(
        self,
        num_nodes,
        input_features,
        output_features,
        bias=True,
        embedding_size: Optional[int] = None,
        phi_dims: Optional[list] = None,
        f_dims: Optional[list] = None,
        **kwargs,
    ):
        super().__init__()
        if embedding_size is None:
            embedding_size = 16
        if phi_dims is None:
            phi_dims = [64, 64]
        if f_dims is None:
            f_dims = [64, 64]
        self.embedding_size = embedding_size
        self.phi_dims = phi_dims
        self.f_dims = f_dims
        # Fix: initialize the embedding explicitly — nn.Parameter(torch.Tensor(...))
        # wrapped uninitialized memory (could contain NaN/inf).
        self.node_embedding = nn.Parameter(torch.randn(embedding_size, num_nodes))
        self.phi_net = MLP([embedding_size + input_features, *phi_dims, embedding_size], bias=bias)
        self.f_net = MLP([embedding_size, *f_dims, output_features], bias=bias)

    def forward(self, x: torch.Tensor, G: torch.Tensor):
        """Return a permutation-invariant summary of the per-node features.

        Args:
            x: per-node features — presumably [..., num_nodes, input_features];
                TODO confirm against callers.
            G: ignored; the summary depends only on x and the embeddings.
        """
        del G
        # Fix: the original referenced the nonexistent attribute
        # ``self.node_embeddings`` and called torch.cat without a tensor list,
        # which raised at runtime. node_embedding is stored as
        # (embedding_size, num_nodes); transpose to (num_nodes, embedding_size)
        # and broadcast over any leading batch dims of x.
        emb = self.node_embedding.t().expand(*x.shape[:-2], -1, -1)
        h = self.phi_net(torch.cat([emb, x], dim=-1))

        # Sum over the node axis: the permutation-invariant pooling step.
        h = torch.sum(h, dim=-2)

        # Fix: f_net was constructed but never applied, leaving dead
        # parameters and an output of embedding_size instead of the
        # documented output_features.
        return self.f_net(h)
|
|
|
|
|
class HyperMLP(nn.Module):
    """Hypernetwork mapping a graph (flattened adjacency matrix) to the
    weights and biases of a per-node linear layer.

    The input adjacency of shape [n_ens, d, d] is flattened to d**2 features,
    passed through an ELU MLP, and the output vector is sliced into a weight
    block and (optionally) a bias block.
    """

    def __init__(
        self,
        num_linear,
        input_features,
        output_features,
        bias=True,
        hidden_dims: Optional[list] = None,
        **kwargs,
    ):
        super().__init__()
        if hidden_dims is None:
            hidden_dims = [1024, 512, 128, 64]

        self.dims = hidden_dims
        self.num_linear = num_linear
        self.input_features = input_features
        self.output_features = output_features
        self.bias = bias

        # Flat sizes of the emitted weight and bias blocks.
        self.w_features = self.num_linear * self.input_features * self.output_features
        self.b_features = self.num_linear * self.output_features
        self.total_features = self.w_features
        if self.bias:
            self.total_features += self.b_features

        # d**2 adjacency entries in, all layer parameters out.
        full_dims = [num_linear**2, *self.dims, self.total_features]
        layers = []
        for in_dim, out_dim in zip(full_dims[:-1], full_dims[1:]):
            if layers:
                layers.append(nn.ELU())
            layers.append(nn.Linear(in_dim, out_dim))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Map the adjacency batch to (weights, biases).

        Returns:
            weights: [n_ens, num_linear, input_features, output_features]
            biases: [n_ens, num_linear, output_features], or None when the
                module was built with bias=False.
        """
        n_ens = x.shape[0]
        flat = self.net(x.reshape(n_ens, -1))
        weights = flat[:, : self.w_features].reshape(
            n_ens, self.num_linear, self.input_features, self.output_features
        )
        biases = (
            flat[:, self.w_features :].reshape(n_ens, self.num_linear, self.output_features)
            if self.bias
            else None
        )
        return weights, biases
|
|
|
|
|
class HyperAnalyticLinear(LocallyConnected):
    """Analytic linear hyper-net module.

    Locally connected but directly returns weights, computed in closed form
    (per-node ridge regression) rather than learned by gradient descent.
    """

    def __init__(
        self,
        n_ens,
        num_linear,
        input_features,
    ):
        # Deliberately skips LocallyConnected.__init__ and calls nn.Module's
        # instead — presumably to avoid allocating the base class's learned
        # weight Parameter, since weights here are computed in forward().
        super(LocallyConnected, self).__init__()
        self.n_ens = n_ens
        self.num_linear = num_linear
        self.input_features = input_features
        # Square layer: output feature count mirrors the input's.
        self.output_features = input_features
        # Ridge (Tikhonov) regularization strength for the normal-equation solve.
        self.beta = 0.01

        # No bias term; registered as None to keep the module interface intact.
        self.register_parameter("bias", None)

    def analytic_linear(self, x, dx, G):
        """Closed-form ridge-regression weights, one solve per node.

        For each node p this solves the regularized normal equations
            (X_p^T X_p + beta * I) w = X_p^T dx_p,
        where X_p is x masked by the transposed adjacency, so each node's
        regression only sees the features its graph permits.

        Args:
            x: input data; indexed as x[:, :, p, :] per node — presumably
                [n_ens, n_samples, d, d]-shaped; TODO confirm with callers.
            dx: regression targets, indexed dx[:, :, p].
            G: adjacency matrix, moved onto x's device/dtype via ``.to(x)``.

        Returns:
            Per-node weight estimates concatenated along dim 2.
        """
        Gt = torch.transpose(G.to(x), -2, -1).unsqueeze(1)
        # Zero out non-parent entries so each node regresses only on parents.
        x_masked = Gt * x
        A_est = []
        for p in range(self.num_linear):
            w_est = torch.linalg.solve(
                (torch.transpose(x_masked[:, :, p, :], -2, -1) @ x_masked[:, :, p, :])
                + self.beta * torch.eye(self.num_linear).unsqueeze(0).type_as(x_masked),
                torch.transpose(x_masked[:, :, p, :], -2, -1) @ dx[:, :, p],
            )
            A_est.append(w_est)
        A_est = torch.cat(A_est, dim=2)
        return A_est

    def forward(self, x, dx, G):
        # NOTE(review): self.weights is a plain tensor attribute (not a
        # Parameter/buffer), so it is not tracked by state_dict() or .to().
        self.weights = self.analytic_linear(x, dx, G).to(x)
        return self.weights
|
|
|
|
|
class HyperInvariantPerGraph(LocallyConnected):
    """Invariant hyper-net module per graph.

    Holds one learned weight (and optional bias) tensor per ensemble member
    and returns them directly, ignoring the graph input entirely.
    """

    def __init__(self, n_ens, num_linear, input_features, output_features, bias=True, **kwargs):
        # Skip LocallyConnected.__init__ (go straight to nn.Module.__init__):
        # the parameter shapes here carry an extra leading ensemble dim.
        super(LocallyConnected, self).__init__()
        self.n_ens = n_ens
        self.num_linear = num_linear
        self.input_features = input_features
        self.output_features = output_features

        # One weight tensor per ensemble member; filled by reset_parameters().
        weight_shape = (n_ens, num_linear, input_features, output_features)
        self.weight = nn.Parameter(torch.empty(weight_shape))
        if bias:
            self.bias = nn.Parameter(torch.empty(n_ens, num_linear, output_features))
        else:
            # Keep the attribute present (as None) so forward() stays uniform.
            self.register_parameter("bias", None)

        self.reset_parameters()

    def forward(self, input):
        """Return the stored (weight, bias); the graph input is ignored."""
        return self.weight, self.bias
|
|
|
|
|
class HyperInvariant(LocallyConnected):
    """Invariant hyper-net module.

    Locally connected but directly returns its learned weights (and bias),
    shared across all graphs, with a leading singleton ensemble dimension.
    """

    def __init__(self, num_linear, input_features, output_features, bias=True, **kwargs):
        super().__init__(num_linear, input_features, output_features, bias)

    def forward(self, input):
        # Fix: when constructed with bias=False the base class leaves
        # ``self.bias`` as None (see the sibling hyper-nets, which register
        # bias as None), and the previous unconditional
        # ``self.bias.unsqueeze(0)`` crashed with an AttributeError. Callers
        # already handle a None bias explicitly.
        bias = None if self.bias is None else self.bias.unsqueeze(0)
        return self.weight.unsqueeze(0), bias
|
|
|