File size: 866 Bytes
a3682cf | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 | import pandas as pd
from src.graph.graph_builder import build_edge_index, build_edge_features, build_labels
from src.graph.node_features import build_node_features
from src.graph.temporal_split import temporal_split
def build_graph_dataset(df: pd.DataFrame, users: pd.DataFrame):
edge_index = build_edge_index(df)
edge_attr = build_edge_features(df)
y = build_labels(df)
X = build_node_features(df, users)
# Raw timestamps for TGN time encoding
timestamps = df.sort_values("timestamp").reset_index(drop=True)["timestamp"].values
train_mask, val_mask, test_mask, _ = temporal_split(df)
return {
"edge_index": edge_index,
"edge_attr": edge_attr,
"timestamps": timestamps,
"x": X,
"y": y,
"train_mask": train_mask,
"val_mask": val_mask,
"test_mask": test_mask,
} |