| import pickle | |
| import time | |
| from src.tgn.train import train_tgn | |
| from src.tgn.evaluate import evaluate | |
def main(graph_path="data/graph/graph.pkl"):
    """Train a TGN model on a pickled graph, evaluate it, and report results.

    Loads the graph object stored at ``graph_path`` and passes it to
    ``train_tgn``. The returned model, memory and normalisation statistics
    are then handed to ``evaluate``. The function prints:

    - the training runtime (loading plus training),
    - the ROC-AUC and PR-AUC returned by ``evaluate``,
    - the total runtime (loading, training and evaluation).

    Args:
        graph_path: Path to the pickled graph data. Defaults to
            ``"data/graph/graph.pkl"``, the path the script originally
            hard-coded, so calling ``main()`` behaves as before.
    """
    # perf_counter is monotonic and high-resolution. time.time() follows the
    # wall clock and can jump if the system clock is adjusted mid-run.
    start = time.perf_counter()

    # NOTE(review): pickle.load can execute arbitrary code embedded in the
    # file. Only load graph files produced by this project's own pipeline.
    with open(graph_path, "rb") as f:
        graph_data = pickle.load(f)

    model, memory, norm_stats = train_tgn(graph_data)
    train_end = time.perf_counter()
    print("Training complete")
    # Measured before evaluation runs, so this covers loading + training only.
    # It was previously mislabelled as the total runtime.
    print(f"Training runtime: {train_end - start:.2f} seconds")

    # evaluate also returns the predicted probabilities and true labels.
    # They are not used by this script.
    roc, pr, _probs, _y_true = evaluate(model, memory, graph_data, norm_stats)
    print(f"ROC-AUC: {roc:.4f}")
    print(f"PR-AUC: {pr:.4f}")
    # The true end-to-end time: loading, training and evaluation.
    print(f"Total runtime: {time.perf_counter() - start:.2f} seconds")
if __name__ == "__main__":
    # Run the train/evaluate pipeline only when this file is executed as a
    # script. Importing the module must not trigger training.
    main()