import torch
import torch.nn as nn
from torch_geometric.nn import SAGEConv


class EdgeGNN(nn.Module):
    """Score edges from node embeddings produced by two SAGEConv layers.

    Node features pass through two ``SAGEConv`` layers with a ReLU in
    between. For every requested (src, dst) pair, the two endpoint
    embeddings are concatenated with that pair's edge attributes and fed
    to a small MLP that emits one scalar per pair.
    """

    def __init__(self, in_channels: int, hidden_dim: int, edge_dim: int) -> None:
        """Build the two graph convolutions and the edge-scoring MLP.

        Args:
            in_channels: Input size passed to the first ``SAGEConv``.
            hidden_dim: Output size of both ``SAGEConv`` layers and width
                of the MLP's hidden layer.
            edge_dim: Number of edge-attribute columns concatenated to the
                two endpoint embeddings before the MLP.
        """
        super().__init__()
        self.conv1 = SAGEConv(in_channels, hidden_dim)
        self.conv2 = SAGEConv(hidden_dim, hidden_dim)
        # MLP input is [h_src | h_dst | edge_attr], hence 2 * hidden_dim + edge_dim.
        self.edge_mlp = nn.Sequential(
            nn.Linear(2 * hidden_dim + edge_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
        )

    def forward(
        self,
        x: torch.Tensor,
        edge_index: torch.Tensor,
        edge_attr: torch.Tensor,
        src: torch.Tensor,
        dst: torch.Tensor,
    ) -> torch.Tensor:
        """Return one score per (src, dst) pair.

        Args:
            x: Node features handed to the first ``SAGEConv``
                (presumably ``(num_nodes, in_channels)`` -- TODO confirm).
            edge_index: Graph connectivity used for message passing in
                both convolutions (presumably PyG's ``(2, num_edges)``
                layout -- TODO confirm).
            edge_attr: 2-D tensor with one row per scored pair; it is
                concatenated along ``dim=1`` with the endpoint embeddings,
                so it must have ``edge_dim`` columns.
            src: Indices into the node embeddings selecting each pair's
                source node.
            dst: Indices into the node embeddings selecting each pair's
                destination node.

        Returns:
            A tensor with the MLP's trailing size-1 dimension removed,
            i.e. one value per scored pair.
        """
        h = self.conv1(x, edge_index)
        h = torch.relu(h)
        h = self.conv2(h, edge_index)

        # NOTE(review): src/dst are taken separately from edge_index, so the
        # scored pairs need not be the message-passing edges -- confirm with
        # callers whether that is intended.
        h_src = h[src]
        h_dst = h[dst]
        edge_input = torch.cat([h_src, h_dst, edge_attr], dim=1)

        # Squeeze only the trailing size-1 dim. A bare .squeeze() would also
        # drop the leading dim when exactly one pair is scored, returning a
        # 0-dim scalar instead of a length-1 vector.
        return self.edge_mlp(edge_input).squeeze(-1)