temporal-twins-code / src /gnn /edge_dataset.py
temporal-twins-anon's picture
Add anonymous Temporal Twins code release
a3682cf verified
raw
history blame contribute delete
587 Bytes
import torch
from torch.utils.data import Dataset
class EdgeDataset(Dataset):
    """Per-edge view over a chosen subset of a graph's edges.

    The constructor keeps only the edges selected by ``indices``; each
    item is then a dict describing one of those retained edges.
    """

    def __init__(self, edge_index, edge_attr, y, indices):
        """Store the edges picked out by ``indices``.

        Args:
            edge_index: Indexed as ``edge_index[:, indices]``, so edges lie
                along its second axis. Row 0 is later read as the source
                endpoint and row 1 as the destination endpoint.
            edge_attr: Per-edge attributes, selected along the first axis.
            y: Per-edge labels, selected along the first axis.
            indices: Selector applied identically to all three inputs.
        """
        # The same selector is applied to every array so that position i
        # refers to the same edge in all of them.
        self.edge_index = edge_index[:, indices]
        self.edge_attr = edge_attr[indices]
        self.y = y[indices]

    def __len__(self):
        """Return the number of retained edges (rows of ``edge_attr``)."""
        attr_shape = self.edge_attr.shape
        return attr_shape[0]

    def __getitem__(self, idx):
        """Return the edge at position ``idx`` within the retained subset.

        The result is a dict with keys ``"src"``, ``"dst"``,
        ``"edge_attr"`` and ``"label"``.
        """
        endpoints = self.edge_index
        sample = {"src": endpoints[0, idx], "dst": endpoints[1, idx]}
        sample["edge_attr"] = self.edge_attr[idx]
        sample["label"] = self.y[idx]
        return sample