| import torch | |
| from torch.utils.data import Dataset | |
class EdgeDataset(Dataset):
    """Map-style dataset over a chosen subset of a graph's edges.

    Args:
        edge_index: Edge endpoint array. Edges are selected column-wise
            (``edge_index[:, indices]``). Row 0 is returned as ``"src"`` and
            row 1 as ``"dst"``. Presumably shaped ``(2, num_edges)``; confirm
            against callers.
        edge_attr: Per-edge features, selected row-wise with ``indices``.
        y: Per-edge labels, selected row-wise with ``indices``.
        indices: Anything the arrays accept as an index (e.g. list, index
            tensor) picking which edges belong to this dataset.
    """

    def __init__(self, edge_index, edge_attr, y, indices):
        # Subset once up front so every later lookup is a plain positional
        # index into the already-filtered arrays.
        self.edge_index = edge_index[:, indices]
        self.edge_attr = edge_attr[indices]
        self.y = y[indices]

    def __len__(self):
        """Return the number of selected edges (leading dim of ``edge_attr``)."""
        return self.edge_attr.shape[0]

    def __getitem__(self, idx):
        """Return the ``idx``-th selected edge as a dict.

        Keys, in order: ``"src"``, ``"dst"``, ``"edge_attr"``, ``"label"``.
        """
        endpoints = self.edge_index
        sample = dict(
            src=endpoints[0, idx],
            dst=endpoints[1, idx],
            edge_attr=self.edge_attr[idx],
            label=self.y[idx],
        )
        return sample