File size: 587 Bytes
a3682cf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
import torch
from torch.utils.data import Dataset


class EdgeDataset(Dataset):
    """Map-style dataset over a chosen subset of graph edges.

    Every item describes a single edge: its source node, its destination
    node, its attribute row and its label.
    """

    def __init__(self, edge_index, edge_attr, y, indices):
        """Keep only the edges picked out by ``indices``.

        Args:
            edge_index: Edge connectivity. Row 0 holds sources and row 1
                holds destinations, and columns are edges. It is presumably
                shaped (2, num_edges); confirm with callers.
            edge_attr: Per-edge attributes whose first dimension indexes
                edges.
            y: Per-edge labels whose first dimension indexes edges.
            indices: Selector (index list/tensor, slice or mask) applied to
                the edge dimension of all three inputs.
        """
        # edge_index stores edges as columns; the other inputs store them as rows.
        self.edge_index = edge_index[:, indices]
        self.edge_attr = edge_attr[indices]
        self.y = y[indices]

    def __len__(self):
        """Return how many edges were selected (first dim of ``edge_attr``)."""
        return self.edge_attr.shape[0]

    def __getitem__(self, idx):
        """Return the selected edge at position ``idx``.

        Returns:
            dict with keys ``"src"``, ``"dst"``, ``"edge_attr"`` and
            ``"label"``, each taken from the stored, already-filtered data.
        """
        sources = self.edge_index[0]
        targets = self.edge_index[1]
        return dict(
            src=sources[idx],
            dst=targets[idx],
            edge_attr=self.edge_attr[idx],
            label=self.y[idx],
        )