# temporal-twins-code/src/graph/dataset_builder.py
# Author: temporal-twins-anon
# Commit a3682cf (verified): "Add anonymous Temporal Twins code release"
import pandas as pd
from src.graph.graph_builder import build_edge_index, build_edge_features, build_labels
from src.graph.node_features import build_node_features
from src.graph.temporal_split import temporal_split
def build_graph_dataset(df: pd.DataFrame, users: pd.DataFrame) -> dict:
    """Assemble edges, node features, labels, timestamps and split masks.

    Every component is produced by a helper from ``src.graph``; this
    function only wires their outputs together into a single dict.

    Args:
        df: Edge/interaction table. It must contain a ``"timestamp"``
            column; any other required columns are determined by the
            ``graph_builder``, ``node_features`` and ``temporal_split``
            helpers.
        users: Table forwarded unchanged to ``build_node_features``.

    Returns:
        A dict with keys ``"edge_index"``, ``"edge_attr"``, ``"timestamps"``,
        ``"x"``, ``"y"``, ``"train_mask"``, ``"val_mask"`` and
        ``"test_mask"``. ``"timestamps"`` holds the values of the
        ``"timestamp"`` column (via ``Series.values``) in ascending order;
        the remaining entries are whatever the corresponding helpers return.
    """
    edge_index = build_edge_index(df)
    edge_attr = build_edge_features(df)
    y = build_labels(df)
    x = build_node_features(df, users)

    # Raw timestamps for TGN time encoding. Sort just this one column rather
    # than sorting and re-indexing the whole frame, which copied every
    # column only to read this one. ``.values`` (not ``.to_numpy()``) is kept
    # so the output dtype is unchanged, e.g. for tz-aware datetimes.
    # NOTE(review): this ordering is computed independently of edge_index /
    # edge_attr / y, so it only lines up with them if the graph_builder
    # helpers also emit edges in ascending-timestamp order — confirm.
    timestamps = df["timestamp"].sort_values().values

    # temporal_split returns four values; the fourth is not used here.
    train_mask, val_mask, test_mask, _ = temporal_split(df)

    return {
        "edge_index": edge_index,
        "edge_attr": edge_attr,
        "timestamps": timestamps,
        "x": x,
        "y": y,
        "train_mask": train_mask,
        "val_mask": val_mask,
        "test_mask": test_mask,
    }