File size: 798 Bytes
a3682cf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
import torch
import torch.nn as nn
from torch_geometric.nn import SAGEConv


class EdgeGNN(nn.Module):
    """Score node pairs with a two-layer GraphSAGE encoder and an edge MLP.

    Node embeddings are produced by two ``SAGEConv`` layers run over
    ``edge_index``, with a ReLU between them and no activation after the
    second. Each queried pair ``(src[i], dst[i])`` is then scored by
    concatenating the two endpoint embeddings with that pair's feature row
    ``edge_attr[i]`` and feeding the result through a small MLP. The MLP
    emits one value per pair. No final activation is applied, so the values
    are unbounded (e.g. logits).

    Args:
        in_channels: Number of input features per node.
        hidden_dim: Width of the node embeddings and of the MLP hidden layer.
        edge_dim: Number of features per queried pair in ``edge_attr``.
    """

    def __init__(self, in_channels: int, hidden_dim: int, edge_dim: int) -> None:
        super().__init__()

        self.conv1 = SAGEConv(in_channels, hidden_dim)
        self.conv2 = SAGEConv(hidden_dim, hidden_dim)

        # MLP input is [h_src | h_dst | edge_attr], hence 2 * hidden_dim + edge_dim.
        self.edge_mlp = nn.Sequential(
            nn.Linear(2 * hidden_dim + edge_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
        )

    def forward(
        self,
        x: torch.Tensor,
        edge_index: torch.Tensor,
        edge_attr: torch.Tensor,
        src: torch.Tensor,
        dst: torch.Tensor,
    ) -> torch.Tensor:
        """Return one score per queried ``(src[i], dst[i])`` pair.

        Args:
            x: Node feature matrix with ``in_channels`` columns.
            edge_index: Graph connectivity used only for message passing in
                the two ``SAGEConv`` layers.
            edge_attr: Per-pair features with ``edge_dim`` columns, one row
                per entry of ``src``/``dst``. Its rows align with the queried
                pairs, not with ``edge_index``. A flat 1-D tensor is
                accepted and treated as a single feature column.
            src: Node indices of the first endpoint of each queried pair.
            dst: Node indices of the second endpoint of each queried pair.

        Returns:
            A 1-D tensor with one unnormalized score per pair. It has shape
            ``(1,)`` even when a single pair is scored.
        """
        h = torch.relu(self.conv1(x, edge_index))
        h = self.conv2(h, edge_index)

        if edge_attr.dim() == 1:
            # Allow a flat (num_pairs,) tensor when edge_dim == 1; concatenation
            # along dim=1 below needs a 2-D tensor.
            edge_attr = edge_attr.unsqueeze(-1)

        edge_input = torch.cat([h[src], h[dst], edge_attr], dim=1)

        # squeeze(-1) removes only the MLP's single output column. A bare
        # squeeze() would also collapse the pair axis when there is exactly
        # one pair, returning a 0-d tensor instead of shape (1,).
        return self.edge_mlp(edge_input).squeeze(-1)