| import torch |
| import torch.nn as nn |
| from torch_geometric.nn import SAGEConv |
|
|
|
|
class EdgeGNN(nn.Module):
    """Score edges using node embeddings from a two-layer ``SAGEConv`` encoder.

    Node features are passed through two ``SAGEConv`` layers with a ReLU
    between them. Each queried edge ``(src[i], dst[i])`` is then represented
    by concatenating its two endpoint embeddings with its edge features. An
    MLP maps that vector to a single scalar.

    Args:
        in_channels: Size of each input node feature vector.
        hidden_dim: Size of the node embeddings and of the MLP hidden layer.
        edge_dim: Size of each edge feature vector in ``edge_attr``.
    """

    def __init__(self, in_channels: int, hidden_dim: int, edge_dim: int):
        super().__init__()

        # Two rounds of neighbourhood aggregation over ``edge_index``.
        # ``edge_attr`` is not used here; it only enters the edge scorer.
        self.conv1 = SAGEConv(in_channels, hidden_dim)
        self.conv2 = SAGEConv(hidden_dim, hidden_dim)

        # Edge scorer: [h_src || h_dst || edge_attr] -> hidden_dim -> 1.
        self.edge_mlp = nn.Sequential(
            nn.Linear(2 * hidden_dim + edge_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
        )

    def forward(
        self,
        x: torch.Tensor,
        edge_index: torch.Tensor,
        edge_attr: torch.Tensor,
        src: torch.Tensor,
        dst: torch.Tensor,
    ) -> torch.Tensor:
        """Return one score per queried edge.

        Args:
            x: Node feature matrix. Presumably ``(num_nodes, in_channels)``;
                confirm against callers.
            edge_index: Graph connectivity used for message passing. It is
                passed unchanged to both ``SAGEConv`` layers.
            edge_attr: Features of the edges being scored, shape
                ``(E, edge_dim)``. A 1-D tensor of length ``E`` is also
                accepted, which only fits the MLP when ``edge_dim == 1``.
            src: Source-node indices of the edges being scored. Assumed to be
                a 1-D index tensor of length ``E``.
            dst: Destination-node indices of the edges being scored. Same
                length as ``src``.

        Returns:
            Tensor of shape ``(E,)`` with the MLP output for each edge. No
            final activation is applied.
        """
        h = torch.relu(self.conv1(x, edge_index))
        h = self.conv2(h, edge_index)

        # Allow (E,) edge features so the dim=1 concatenation below is valid.
        if edge_attr.dim() == 1:
            edge_attr = edge_attr.unsqueeze(-1)

        edge_input = torch.cat([h[src], h[dst], edge_attr], dim=1)

        # Drop only the trailing size-1 score dimension. A bare ``squeeze()``
        # would also collapse the edge dimension when E == 1, returning a 0-d
        # tensor instead of shape (1,).
        return self.edge_mlp(edge_input).squeeze(-1)