import torch
from torch.utils.data import Dataset


class EdgeDataset(Dataset):
    """Map-style dataset that serves one edge (endpoints, attributes, label) per item.

    The constructor selects a subset of edges once, up front, so later item
    lookups index directly into the already-filtered arrays.
    """

    def __init__(self, edge_index, edge_attr, y, indices):
        """Keep only the edges selected by ``indices``.

        Args:
            edge_index: 2-D array indexed as ``[row, edge]``; row 0 is served
                under the key ``"src"`` and row 1 under ``"dst"``.
                Presumably a ``torch.Tensor`` -- confirm with the caller.
            edge_attr: Per-edge attributes, indexed along the first axis.
            y: Per-edge labels, indexed along the first axis.
            indices: Selector applied to the edge axis of all three inputs
                (anything the inputs accept as an index, e.g. an index
                tensor -- TODO confirm which forms callers pass).
        """
        super().__init__()
        # The same selector is applied to all three inputs so they stay aligned.
        self.edge_index = edge_index[:, indices]
        self.edge_attr = edge_attr[indices]
        self.y = y[indices]

    def __len__(self) -> int:
        """Return the number of selected edges (first dimension of ``edge_attr``)."""
        return self.edge_attr.shape[0]

    def __getitem__(self, idx) -> dict:
        """Return the edge at position ``idx`` within the selected subset.

        Returns:
            A dict with keys ``"src"``, ``"dst"``, ``"edge_attr"`` and
            ``"label"``.
        """
        src = self.edge_index[0, idx]
        dst = self.edge_index[1, idx]
        return {
            "src": src,
            "dst": dst,
            "edge_attr": self.edge_attr[idx],
            "label": self.y[idx],
        }