xiangzai's picture
Add files using upload-large-folder tool
3e4f775 verified
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from .base import LocallyConnected
class ResNetBlock(nn.Module):
    """Residual MLP block computing ``relu(f(x) + shortcut(x))``.

    ``f`` is ``Linear(in_size, h_size) -> ReLU -> Linear(h_size, out_size)``.
    The shortcut is the identity when ``in_size == out_size``. Otherwise it is a
    bias-free linear projection to ``out_size``, so the residual sum has matching
    shapes (previously any ``in_size != out_size`` crashed in ``forward``).

    Args:
        in_size: size of the last input dimension.
        h_size: hidden width of ``f``.
        out_size: size of the last output dimension.
    """

    def __init__(self, in_size=16, h_size=16, out_size=16):
        super().__init__()
        self.f = nn.Sequential(
            nn.Linear(in_features=in_size, out_features=h_size),
            nn.ReLU(),
            nn.Linear(in_features=h_size, out_features=out_size),
        )
        if in_size == out_size:
            # Empty Sequential is the identity and has no parameters.
            self.shortcut = nn.Sequential()
        else:
            self.shortcut = nn.Linear(in_size, out_size, bias=False)

    def forward(self, x):
        return F.relu(self.f(x) + self.shortcut(x))


class HyperResNet(nn.Module):
    """Stack of ``num_block`` :class:`ResNetBlock` modules applied in sequence.

    The first block maps ``in_size -> out_size``. Every later block maps
    ``out_size -> out_size``, because it receives the previous block's output
    rather than the original input.

    Args:
        in_size: size of the last input dimension.
        h_size: hidden width inside each block.
        out_size: size of the last output dimension.
        num_block: number of stacked blocks.
    """

    def __init__(self, in_size=16, h_size=16, out_size=16, num_block=2):
        super().__init__()
        blocks = [
            ResNetBlock(in_size if i == 0 else out_size, h_size, out_size)
            for i in range(num_block)
        ]
        self.model = nn.Sequential(*blocks)

    def forward(self, x):
        return self.model(x)
class HyperLocallyConnected(nn.Module):
    """Locally connected layer whose parameters are produced by a hypernetwork.

    This is a per-node linear map (``Conv1dLocal()`` with filter size 1). The
    module does not own its weights; they are generated from the graph ``G``
    on every call::

        y = LocallyConnected_{params}(x),  where params = h(G)

    Args:
        num_linear: number of local linear layers ``d``.
        input_features: per-node input size ``m1``.
        output_features: per-node output size ``m2``.
        hyper: hypernetwork type, one of ``VALID_HYPER``. ``"gnn"`` is listed
            but not implemented and raises ``NotImplementedError``.
        n_ens: ensemble size forwarded to the hypernetwork.
        bias: whether the hypernetwork also produces biases.
        hyper_hidden_dims: hidden sizes forwarded to the hypernetwork as
            ``hidden_dims``.

    Shape (from the matmul in ``forward``):
        - Input ``x``: ``[n_ens, n, d, 1, m1]``
        - Output: ``[n_ens, n, d, 1, m2]``
    """

    VALID_HYPER = [
        "mlp",
        "gnn",
        "invariant",
        "per_graph",
        "deep_set",
    ]

    def __init__(
        self,
        num_linear,
        input_features,
        output_features,
        hyper,
        n_ens=1,
        bias=True,
        hyper_hidden_dims: Optional[list] = None,
    ):
        super().__init__()
        self.num_linear = num_linear
        self.input_features = input_features
        self.output_features = output_features
        self.n_ens = n_ens
        self.hyper = hyper
        assert (
            hyper in self.VALID_HYPER
        ), f"hyper hparam not a valid option - choices: {self.VALID_HYPER}"
        if hyper == "invariant":
            hyper_type = HyperInvariant
        elif hyper == "mlp":
            hyper_type = HyperMLP
        elif hyper == "per_graph":
            hyper_type = HyperInvariantPerGraph
        elif hyper == "deep_set":
            hyper_type = DeepSet
        else:
            # "gnn" passes the VALID_HYPER check but has no implementation;
            # previously this fell through and hit an UnboundLocalError.
            raise NotImplementedError(f"hyper={hyper!r} is not implemented")
        self.hyper_layer = hyper_type(
            n_ens=n_ens,
            num_linear=num_linear,
            input_features=input_features,
            output_features=output_features,
            bias=bias,
            hidden_dims=hyper_hidden_dims,
        )

    def forward(self, x: torch.Tensor, G: torch.Tensor) -> torch.Tensor:
        """Apply the per-node linear maps generated from ``G`` to ``x``.

        Args:
            x: input of shape ``[n_ens, n, d, 1, m1]``.
            G: graph passed (cast to ``x``'s dtype/device) to the hypernetwork.

        Returns:
            Tensor of shape ``[n_ens, n, d, 1, m2]``.
        """
        # [n_ens, n, d, 1, m2] = [n_ens, n, d, 1, m1] @ [n_ens, 1, d, m1, m2]
        weights, biases = self.hyper_layer(G.to(x))
        x = torch.matmul(x, weights.unsqueeze(1))
        if biases is not None:
            # [n_ens, n, d, 1, m2] += [n_ens, 1, d, 1, m2]
            x += biases.unsqueeze(-2).unsqueeze(1)
        return x
class AnalyiticLinearLocallyConnected(nn.Module):
    """Locally connected layer whose weights are solved in closed form.

    Nothing here is learned. Each ``forward`` call hands ``(x, dx, G)`` to a
    ``HyperAnalyticLinear`` solver, keeps the returned weights on
    ``self.weights`` and multiplies them (transposed) with ``x``.

    Args:
        num_linear: number of nodes ``d``.
        input_features: per-node input size, forwarded to the solver.
        hyper: stored on ``self.hyper``; not otherwise used by this class.
        n_ens: ensemble size, forwarded to the solver.
        bias: accepted for interface compatibility; unused.
        hyper_hidden_dims: accepted for interface compatibility; unused.
    """

    def __init__(
        self,
        num_linear,
        input_features,
        hyper,
        n_ens=1,
        bias=True,
        hyper_hidden_dims: Optional[list] = None,
    ):
        super().__init__()
        self.num_linear = num_linear
        self.input_features = input_features
        self.n_ens = n_ens
        self.hyper = hyper
        solver = HyperAnalyticLinear(
            n_ens=n_ens,
            num_linear=num_linear,
            input_features=input_features,
        )
        self.hyper_layer = solver
        # Random placeholder (a plain tensor attribute, not a parameter or
        # buffer); it is overwritten on every forward() call.
        self.weights = torch.randn((n_ens, num_linear, num_linear))

    def forward(self, x: torch.Tensor, dx: torch.Tensor, G: torch.Tensor):
        """Solve for the weights from ``(x, dx, G)`` and apply them to ``x``.

        The result is ``weights^T @ x`` with ``weights`` given an extra axis at
        position 1. ``x`` has its dim ``-2`` and then dim ``0`` squeezed first.
        NOTE(review): the exact input shapes are not visible here — confirm with callers.
        """
        self.weights = self.hyper_layer(x, dx, G.to(x))
        lhs = self.weights.unsqueeze(1).transpose(-2, -1)
        rhs = x.squeeze(-2).squeeze(0)
        return torch.matmul(lhs, rhs)
class NodeHyperLocallyConnected(nn.Module):
    """Hypernetwork-parameterised locally connected layer with a stored graph.

    This works like a hyper locally connected layer, except that the graph is
    not a ``forward`` argument. Callers must assign it to ``self.G`` before
    calling ``forward``::

        y = LocallyConnected_{params}(x),  where params = h(self.G)

    Args:
        num_linear: number of local linear layers ``d``.
        input_features: per-node input size ``m1``.
        output_features: per-node output size ``m2``.
        hyper: hypernetwork type, one of ``VALID_HYPER``. ``"gnn"`` is listed
            but not implemented and raises ``NotImplementedError``.
        n_ens: ensemble size forwarded to the hypernetwork.
        bias: whether the hypernetwork also produces biases.
        hyper_hidden_dims: hidden sizes forwarded to the hypernetwork as
            ``hidden_dims``.

    Shape (from the matmul in ``forward``):
        - Input ``x``: ``[n_ens, n, d, 1, m1]``
        - Output: ``[n_ens, n, d, 1, m2]``
    """

    VALID_HYPER = [
        "mlp",
        "gnn",
        "invariant",
        "per_graph",
        "deep_set",
    ]

    def __init__(
        self,
        num_linear,
        input_features,
        output_features,
        hyper,
        n_ens=1,
        bias=True,
        hyper_hidden_dims: Optional[list] = None,
    ):
        super().__init__()
        self.num_linear = num_linear
        self.input_features = input_features
        self.output_features = output_features
        self.n_ens = n_ens
        self.hyper = hyper
        # Graph used by forward(); must be assigned by the caller beforehand.
        self.G = None
        assert (
            hyper in self.VALID_HYPER
        ), f"hyper hparam not a valid option - choices: {self.VALID_HYPER}"
        if hyper == "invariant":
            hyper_type = HyperInvariant
        elif hyper == "mlp":
            hyper_type = HyperMLP
        elif hyper == "per_graph":
            hyper_type = HyperInvariantPerGraph
        elif hyper == "deep_set":
            hyper_type = DeepSet
        else:
            # "gnn" passes the VALID_HYPER check but has no implementation;
            # previously this fell through and hit an UnboundLocalError.
            raise NotImplementedError(f"hyper={hyper!r} is not implemented")
        self.hyper_layer = hyper_type(
            n_ens=n_ens,
            num_linear=num_linear,
            input_features=input_features,
            output_features=output_features,
            bias=bias,
            hidden_dims=hyper_hidden_dims,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the per-node linear maps generated from ``self.G`` to ``x``.

        Raises:
            RuntimeError: if ``self.G`` has not been set.
        """
        G = self.G
        if G is None:
            raise RuntimeError(
                "NodeHyperLocallyConnected.G must be set before calling forward()"
            )
        # [n_ens, n, d, 1, m2] = [n_ens, n, d, 1, m1] @ [n_ens, 1, d, m1, m2]
        weights, biases = self.hyper_layer(G.to(x))
        x = torch.matmul(x, weights.unsqueeze(1))
        if biases is not None:
            # [n_ens, n, d, 1, m2] += [n_ens, 1, d, 1, m2]
            x += biases.unsqueeze(-2).unsqueeze(1)
        return x
class MLP(nn.Module):
    """Feed-forward network of ``Linear`` layers joined by ``ELU`` activations.

    Args:
        dims: layer widths ``[in, h1, ..., out]``. ``len(dims) - 1`` linear
            layers are created, with an ``ELU`` between consecutive ones and no
            activation after the last. Fewer than two widths give an empty
            (identity) network.
        bias: whether every ``Linear`` layer has a bias term.
    """

    def __init__(self, dims, bias=True):
        super().__init__()
        modules = []
        for idx, (fan_in, fan_out) in enumerate(zip(dims[:-1], dims[1:])):
            if idx:
                modules.append(nn.ELU())
            modules.append(nn.Linear(fan_in, fan_out, bias=bias))
        self.net = nn.Sequential(*modules)

    def forward(self, x):
        return self.net(x)
class DeepSet(nn.Module):
    """Deep-Sets style network ``f_net(sum(phi_net([node_embedding, x])))``.

    Each node ``j`` owns a learned embedding of size ``embedding_size`` (column
    ``j`` of ``node_embedding``). That embedding is concatenated with the node's
    input features and passed through ``phi_net``. The result is summed over
    dimension ``-2`` and mapped to ``output_features`` by ``f_net``.

    Args:
        num_nodes: number of nodes (columns of ``node_embedding``).
        input_features: size of the last dimension of ``x``.
        output_features: output size of ``f_net``.
        bias: whether the MLP layers use biases.
        embedding_size: node-embedding size (default 16).
        phi_dims: hidden widths of ``phi_net`` (default ``[64, 64]``).
        f_dims: hidden widths of ``f_net`` (default ``[64, 64]``).
        **kwargs: ignored.

    NOTE(review): ``HyperLocallyConnected(hyper="deep_set")`` builds this class
    with ``n_ens``/``num_linear`` (not ``num_nodes``) and expects a
    ``(weights, biases)`` result from ``forward(G)``. That integration still
    does not match this interface.
    """

    def __init__(
        self,
        num_nodes,
        input_features,
        output_features,
        bias=True,
        embedding_size: Optional[int] = None,
        phi_dims: Optional[list] = None,
        f_dims: Optional[list] = None,
        **kwargs,
    ):
        super().__init__()
        if embedding_size is None:
            embedding_size = 16
        if phi_dims is None:
            phi_dims = [64, 64]
        if f_dims is None:
            f_dims = [64, 64]
        self.embedding_size = embedding_size
        self.phi_dims = phi_dims
        self.f_dims = f_dims
        self.node_embedding = nn.Parameter(torch.Tensor(embedding_size, num_nodes))
        # torch.Tensor(...) leaves memory uninitialised; give the embeddings a
        # defined starting point (standard normal, as nn.Embedding does).
        nn.init.normal_(self.node_embedding)
        self.phi_net = MLP([embedding_size + input_features, *phi_dims, embedding_size], bias=bias)
        self.f_net = MLP([embedding_size, *f_dims, output_features], bias=bias)

    def forward(self, x: torch.Tensor, G: torch.Tensor):
        """Embed, pool and project ``x``.

        Args:
            x: node features whose node axis is dim ``-3``, with size
                ``num_nodes``, and last dim ``input_features``. The original shape
                note was ``[n_ens, batch, d, 1, d]`` — TODO confirm with callers.
            G: unused.

        Returns:
            ``x`` with dim ``-2`` summed away and the last dim mapped to
            ``output_features``.
        """
        del G  # the graph is not used by this module
        # [num_nodes, 1, emb], broadcast over every leading dimension of x.
        emb = self.node_embedding.t().unsqueeze(-2)
        emb = emb.expand(*x.shape[:-1], self.embedding_size)
        x = self.phi_net(torch.cat([emb, x], dim=-1))
        # [..., d, k, emb] -> [..., d, emb]
        x = torch.sum(x, dim=-2)
        return self.f_net(x)
class HyperMLP(nn.Module):
    """Hypernetwork that maps a graph to the parameters of per-node linear layers.

    The graph, an adjacency matrix with ``num_linear ** 2`` entries per ensemble
    member, is flattened and fed through an MLP (``Linear`` layers separated by
    ``ELU``). The MLP output is split into the weights and, optionally, the
    biases of ``num_linear`` independent linear maps.

    Args:
        num_linear: number of local linear layers ``d``.
        input_features: per-node input size ``m1``.
        output_features: per-node output size ``m2``.
        bias: whether to also emit biases.
        hidden_dims: hidden widths of the MLP (default ``[1024, 512, 128, 64]``).
        **kwargs: ignored, so callers may pass extras such as ``n_ens``.
    """

    def __init__(
        self,
        num_linear,
        input_features,
        output_features,
        bias=True,
        hidden_dims: Optional[list] = None,
        **kwargs,
    ):
        super().__init__()
        self.dims = [1024, 512, 128, 64] if hidden_dims is None else hidden_dims
        self.num_linear = num_linear
        self.input_features = input_features
        self.output_features = output_features
        self.bias = bias
        self.w_features = num_linear * input_features * output_features
        self.b_features = num_linear * output_features
        self.total_features = self.w_features + (self.b_features if bias else 0)

        widths = [num_linear**2, *self.dims, self.total_features]
        layers = []
        for k, (n_in, n_out) in enumerate(zip(widths[:-1], widths[1:])):
            if k > 0:
                layers.append(nn.ELU())
            layers.append(nn.Linear(n_in, n_out))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return ``(weights, biases)`` generated from adjacency ``x``.

        ``x`` is flattened per leading index, e.g. ``[n_ens, d, d]``.
        ``weights`` is ``[n_ens, num_linear, input_features, output_features]``.
        ``biases`` is ``[n_ens, num_linear, output_features]``, or ``None``
        when built with ``bias=False``.
        """
        n_ens = x.shape[0]
        params = self.net(x.reshape(n_ens, -1))
        weights = params[:, : self.w_features].reshape(
            n_ens, self.num_linear, self.input_features, self.output_features
        )
        if not self.bias:
            return weights, None
        biases = params[:, self.w_features :].reshape(
            n_ens, self.num_linear, self.output_features
        )
        return weights, biases
class HyperAnalyticLinear(LocallyConnected):
    """Analytic (closed-form) linear hyper-net module.

    Instead of learning weights, ``forward`` solves a ridge-regularised
    least-squares problem for every node ``p``::

        (X_p^T X_p + beta * I) w_p = X_p^T dx[:, :, p]

    Here ``X_p = x_masked[:, :, p, :]`` and ``x_masked`` is ``x`` multiplied
    by the transposed graph ``G``. The per-node solutions are concatenated
    along dim 2 and returned; they are also cached on ``self.weights``.

    Args:
        n_ens: ensemble size (stored only).
        num_linear: number of nodes ``d``; also the size of the identity
            regulariser.
        input_features: stored; ``output_features`` is set equal to it.
        beta: ridge regularisation strength. The default 0.01 is the value that
            was previously hard-coded ("per-node-GFN = 0.01").
    """

    def __init__(
        self,
        n_ens,
        num_linear,
        input_features,
        beta=0.01,
    ):
        # Deliberately call nn.Module.__init__ (skipping LocallyConnected's own
        # __init__), so no learnable weight/bias tensors are allocated.
        super(LocallyConnected, self).__init__()
        self.n_ens = n_ens
        self.num_linear = num_linear
        self.input_features = input_features
        self.output_features = input_features
        self.beta = beta
        self.register_parameter("bias", None)

    def analytic_linear(self, x, dx, G):
        """Solve the per-node ridge regressions and concatenate the solutions.

        NOTE(review): shapes are inferred only from the indexing below
        (``x_masked`` is at least 4-D; ``dx[:, :, p]`` selects node ``p``).
        Confirm them with callers.
        """
        Gt = torch.transpose(G.to(x), -2, -1).unsqueeze(1)
        # Zero out the inputs that are not parents according to the graph.
        x_masked = Gt * x
        # The ridge term does not depend on p, so build it once.
        ridge = self.beta * torch.eye(
            self.num_linear, dtype=x_masked.dtype, device=x_masked.device
        ).unsqueeze(0)
        A_est = []
        for p in range(self.num_linear):
            xp = x_masked[:, :, p, :]
            xp_t = torch.transpose(xp, -2, -1)
            w_est = torch.linalg.solve(xp_t @ xp + ridge, xp_t @ dx[:, :, p])
            A_est.append(w_est)
        return torch.cat(A_est, dim=2)

    def forward(self, x, dx, G):
        """Return (and cache on ``self.weights``) the analytic weights cast to ``x``."""
        self.weights = self.analytic_linear(x, dx, G).to(x)
        return self.weights
class HyperInvariantPerGraph(LocallyConnected):
    """Graph-independent "hypernetwork" with one parameter set per ensemble member.

    This module owns a weight tensor ``[n_ens, num_linear, input_features,
    output_features]`` and an optional bias ``[n_ens, num_linear,
    output_features]``. ``forward`` ignores its input and returns both, so the
    parameters are learned directly rather than generated from a graph.

    Args:
        n_ens: ensemble size.
        num_linear: number of local linear layers ``d``.
        input_features: per-node input size ``m1``.
        output_features: per-node output size ``m2``.
        bias: whether to allocate a bias; otherwise ``self.bias`` is ``None``.
        **kwargs: ignored (e.g. ``hidden_dims``).
    """

    def __init__(self, n_ens, num_linear, input_features, output_features, bias=True, **kwargs):
        # Bypass LocallyConnected.__init__ (go straight to nn.Module.__init__);
        # the ensemble-shaped tensors are allocated here instead.
        super(LocallyConnected, self).__init__()
        self.n_ens = n_ens
        self.num_linear = num_linear
        self.input_features = input_features
        self.output_features = output_features
        weight_shape = (n_ens, num_linear, input_features, output_features)
        self.weight = nn.Parameter(torch.Tensor(*weight_shape))
        # Always register "bias", even when it is None.
        self.register_parameter(
            "bias",
            nn.Parameter(torch.Tensor(n_ens, num_linear, output_features)) if bias else None,
        )
        # Inherited from LocallyConnected (not visible here); presumably fills
        # the uninitialised tensors above — confirm.
        self.reset_parameters()

    def forward(self, input):
        """Return ``(weight, bias)``; ``input`` is ignored."""
        return self.weight, self.bias
class HyperInvariant(LocallyConnected):
    """Graph-independent "hypernetwork" sharing one LocallyConnected parameter set.

    ``forward`` ignores its input and returns the module's own ``weight`` and
    ``bias``, each with a leading axis of size 1 added so they broadcast over
    the ensemble dimension. When the module has no bias, ``None`` is returned
    in its place.

    Args:
        num_linear: number of local linear layers ``d``.
        input_features: per-node input size ``m1``.
        output_features: per-node output size ``m2``.
        bias: whether to use a bias term.
        **kwargs: ignored (e.g. ``n_ens``, ``hidden_dims``).
    """

    def __init__(self, num_linear, input_features, output_features, bias=True, **kwargs):
        super().__init__(num_linear, input_features, output_features, bias)

    def forward(self, input):
        """Return ``(weight[None], bias[None] or None)``; ``input`` is ignored."""
        # Assumes LocallyConnected leaves self.bias as None when bias=False (as
        # HyperInvariantPerGraph does); unconditionally unsqueezing it crashed.
        b = None if self.bias is None else self.bias.unsqueeze(0)
        return self.weight.unsqueeze(0), b