File size: 587 Bytes
a3682cf | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 | import torch
from torch.utils.data import Dataset
class EdgeDataset(Dataset):
    """Map-style dataset exposing a selected subset of graph edges.

    Every item describes a single edge: the node ids at both ends, the
    edge's attribute entry, and its label.

    Args:
        edge_index: Connectivity tensor whose columns are edges; row 0 is
            read as the source node and row 1 as the destination node.
        edge_attr: Per-edge attributes, selected along the first dimension.
        y: Per-edge labels, selected along the first dimension.
        indices: Which edges to keep. It is applied as a column selector on
            ``edge_index`` and a row selector on ``edge_attr`` / ``y``.
            Presumably an index tensor, list, or boolean mask; confirm
            against the callers.
    """

    def __init__(self, edge_index, edge_attr, y, indices):
        # Select the chosen edges once up front so that item access is a
        # plain lookup into the reduced tensors.
        self.edge_index = edge_index[:, indices]
        self.edge_attr = edge_attr[indices]
        self.y = y[indices]

    def __len__(self):
        # The edge count is taken from the leading dimension of the
        # selected attributes.
        n_edges = self.edge_attr.shape[0]
        return n_edges

    def __getitem__(self, idx):
        endpoints = self.edge_index
        sample = {
            "src": endpoints[0, idx],
            "dst": endpoints[1, idx],
            "edge_attr": self.edge_attr[idx],
            "label": self.y[idx],
        }
        return sample