import torch
from sklearn.metrics import roc_auc_score, average_precision_score


def evaluate_gnn(model, graph_data, *, device="cpu"):
    """Run ``model`` on ``graph_data`` and return ROC-AUC and average precision.

    The four entries of ``graph_data`` are converted to tensors, the model is
    called once with gradients disabled, its output is passed through a
    sigmoid, and the resulting scores are compared against ``graph_data["y"]``
    with scikit-learn's ``roc_auc_score`` and ``average_precision_score``.

    Args:
        model: Callable object with an ``eval()`` method, invoked as
            ``model(x, edge_index, edge_attr, src, dst)``. It is switched to
            eval mode and NOT switched back afterwards.
        graph_data: Mapping with the keys ``"edge_index"``, ``"edge_attr"``,
            ``"x"`` and ``"y"``. ``edge_index`` is converted to ``long``; the
            other three to ``float32``. ``edge_index`` must be indexable at
            0 and 1 along its first axis (presumably shape
            ``(2, num_edges)`` -- confirm against the caller).
        device: Device the input tensors are placed on. Defaults to ``"cpu"``.
            The model is not moved; it is assumed to already live on this
            device.

    Returns:
        A ``(roc, pr)`` tuple: the ROC-AUC score followed by the average
        precision score.

    Raises:
        KeyError: If ``graph_data`` lacks one of the four required keys
            (for a plain ``dict``).
        ValueError: Propagated from scikit-learn when the metrics cannot be
            computed, e.g. presumably when ``y`` contains a single class.
    """
    device = torch.device(device)

    # as_tensor reuses existing tensor/array storage when dtype and device
    # already match, instead of copy-constructing like torch.tensor() does.
    edge_index = torch.as_tensor(
        graph_data["edge_index"], dtype=torch.long, device=device
    )
    edge_attr = torch.as_tensor(
        graph_data["edge_attr"], dtype=torch.float32, device=device
    )
    x = torch.as_tensor(graph_data["x"], dtype=torch.float32, device=device)
    y = torch.as_tensor(graph_data["y"], dtype=torch.float32, device=device)

    # The first two rows of edge_index are also passed to the model separately.
    src = edge_index[0]
    dst = edge_index[1]

    model.eval()
    with torch.no_grad():
        logits = model(x, edge_index, edge_attr, src, dst)
        # Sigmoid turns the raw model output into scores in (0, 1); the
        # result is moved to host memory because scikit-learn needs NumPy.
        probs = torch.sigmoid(logits).cpu().numpy()

    y_true = y.cpu().numpy()

    roc = roc_auc_score(y_true, probs)
    pr = average_precision_score(y_true, probs)
    return roc, pr