temporal-twins-anon's picture
Add anonymous Temporal Twins code release
a3682cf verified
raw
history blame contribute delete
798 Bytes
import torch
import torch.nn as nn
from torch_geometric.nn import SAGEConv
class EdgeGNN(nn.Module):
    """Edge-level scorer: two-layer GraphSAGE node encoder followed by an edge MLP.

    Node embeddings come from two ``SAGEConv`` layers, with a ReLU between
    them and no activation after the second. Each queried edge
    ``(src[i], dst[i])`` is then scored by concatenating:

    - the embedding of node ``src[i]``,
    - the embedding of node ``dst[i]``,
    - that edge's feature row ``edge_attr[i]``.

    The concatenation goes through a small MLP that emits one value per edge.
    No output activation is applied, so scores are raw. They are presumably
    logits; confirm against the training loss.

    Args:
        in_channels: Size of each input node feature vector.
        hidden_dim: Width of the node embeddings and of the MLP hidden layer.
        edge_dim: Number of features per queried edge in ``edge_attr``.
    """

    def __init__(self, in_channels: int, hidden_dim: int, edge_dim: int):
        super().__init__()
        # Attribute names and the Sequential layout determine the state_dict
        # keys (conv1.*, conv2.*, edge_mlp.0.*, edge_mlp.2.*). Keep them stable
        # so previously saved checkpoints still load.
        self.conv1 = SAGEConv(in_channels, hidden_dim)
        self.conv2 = SAGEConv(hidden_dim, hidden_dim)
        self.edge_mlp = nn.Sequential(
            nn.Linear(2 * hidden_dim + edge_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
        )

    def forward(
        self,
        x: torch.Tensor,
        edge_index: torch.Tensor,
        edge_attr: torch.Tensor,
        src: torch.Tensor,
        dst: torch.Tensor,
    ) -> torch.Tensor:
        """Return one raw score per queried edge.

        Args:
            x: Node feature matrix, one row per node, last dim ``in_channels``.
            edge_index: Graph connectivity fed to both SAGEConv layers. Assumed
                to be PyG's ``(2, num_edges)`` index format; TODO confirm.
            edge_attr: Features of the queried edges, shape ``(E, edge_dim)``.
                A flat ``(E,)`` tensor is also accepted and treated as
                ``(E, 1)``.
            src: Source-node indices of the E queried edges.
            dst: Destination-node indices of the E queried edges.

        Note:
            The queried edges need not be the edges in ``edge_index``; this
            cannot be determined from here.

        Returns:
            Tensor of shape ``(E,)``. This holds even for ``E == 1``.
        """
        h = self._encode_nodes(x, edge_index)
        return self._score_edges(h, edge_attr, src, dst)

    def _encode_nodes(self, x: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor:
        """Compute node embeddings: conv1, then ReLU, then conv2 (no final activation)."""
        h = torch.relu(self.conv1(x, edge_index))
        return self.conv2(h, edge_index)

    def _score_edges(
        self,
        h: torch.Tensor,
        edge_attr: torch.Tensor,
        src: torch.Tensor,
        dst: torch.Tensor,
    ) -> torch.Tensor:
        """Score each ``(src[i], dst[i])`` pair from node embeddings ``h`` and ``edge_attr``."""
        if edge_attr.dim() == 1:
            # A flat (E,) tensor is the natural shape when edge_dim == 1.
            # Lift it to (E, 1) so it concatenates along dim=1 below.
            edge_attr = edge_attr.unsqueeze(-1)
        edge_input = torch.cat([h[src], h[dst], edge_attr], dim=1)
        # squeeze(-1), not squeeze(): drop only the trailing size-1 output
        # column. A bare squeeze() would turn a single-edge batch (1, 1) into a
        # 0-d scalar instead of shape (1,).
        return self.edge_mlp(edge_input).squeeze(-1)