| import pickle | |
| import time | |
| from src.gnn.train import train_gnn | |
| from src.gnn.evaluate import evaluate_gnn | |
def main(graph_path="data/graph/graph.pkl"):
    """Load the pickled graph, train a GNN on it, and report metrics.

    Prints a completion message, the elapsed time for loading and
    training, and the ROC-AUC / PR-AUC pair returned by ``evaluate_gnn``.

    Args:
        graph_path: Path to the pickled graph data. Defaults to the
            location this script has always read from, so existing
            ``main()`` calls behave as before.
    """
    # perf_counter is monotonic and high-resolution, so the measured
    # duration is not skewed by system clock adjustments (unlike time.time).
    start = time.perf_counter()

    # SECURITY: pickle.load can execute arbitrary code while deserialising.
    # Only point graph_path at files produced by this project's own pipeline.
    with open(graph_path, "rb") as f:
        graph_data = pickle.load(f)

    model = train_gnn(graph_data)
    end = time.perf_counter()

    print("Training complete")
    # NOTE: this figure covers graph loading + training only; the
    # evaluation below is not included despite the "Total" label.
    print(f"Total runtime: {end - start:.2f} seconds")

    # evaluate_gnn is expected to return (roc_auc, pr_auc), as unpacked here.
    roc, pr = evaluate_gnn(model, graph_data)
    print(f"GNN ROC-AUC: {roc:.4f}")
    print(f"GNN PR-AUC: {pr:.4f}")
if __name__ == "__main__":
    # Run the train/evaluate pipeline only when executed as a script,
    # not when this module is imported.
    main()