import contextlib

import torch
from sklearn.metrics import roc_auc_score, average_precision_score

from src.tgn.time_encoding import TimeEncoding
from src.tgn.memory import Memory


@contextlib.contextmanager
def _eval_mode(*modules):
    """Temporarily switch every ``torch.nn.Module`` in *modules* to eval mode.

    The previous ``training`` flag of each module is restored on exit, so a
    caller evaluating in the middle of training gets its modules back in the
    mode it left them. Objects that are not ``nn.Module`` are ignored.
    """
    saved = [(m, m.training) for m in modules if isinstance(m, torch.nn.Module)]
    for module, _ in saved:
        module.eval()
    try:
        yield
    finally:
        for module, was_training in saved:
            module.train(was_training)


def _replay_edges(model, memory, time_encoder, edge_index, edge_attr,
                  timestamps, edge_ids, batch_size):
    """Update ``memory`` in place by replaying ``edge_ids`` in chunks of ``batch_size``."""
    for start in range(0, len(edge_ids), batch_size):
        batch_ids = edge_ids[start:start + batch_size]
        u_i = edge_index[0, batch_ids]
        v_i = edge_index[1, batch_ids]
        time_enc_i = time_encoder(timestamps[batch_ids])
        h_u_i = memory.get(u_i)
        h_v_i = memory.get(v_i)
        msg = model.compute_message(
            h_u_i.detach(), h_v_i.detach(), edge_attr[batch_ids], time_enc_i
        )

        # Every edge delivers the same message to both of its endpoints.
        node_ids = torch.cat([u_i, v_i])
        messages = torch.cat([msg, msg])

        # Mean-aggregate per node so that a node hit by several edges in this
        # batch receives a single memory update.
        unique_nodes, inverse_idx = torch.unique(node_ids, return_inverse=True)
        agg_msg = torch.zeros_like(memory.memory[unique_nodes])
        agg_msg.index_add_(0, inverse_idx, messages)
        counts = torch.bincount(inverse_idx).unsqueeze(1)
        memory.update(unique_nodes, agg_msg / counts)


def evaluate(model, memory, graph_data, norm_stats, *, time_encoder=None,
             memory_dim=64, batch_size=1024):
    """Score the test edges of ``graph_data`` with a trained TGN-style model.

    Node memory is rebuilt from scratch by replaying only the training edges.
    The model then predicts a logit for every test edge, and ROC-AUC and
    average precision are computed against the edge labels.

    Args:
        model: Object exposing ``compute_message(h_u, h_v, edge_feat, time_enc)``
            and ``predict(h_u, h_v, edge_feat, x_u, x_v, time_enc)``. If it is a
            ``torch.nn.Module``, it is put in eval mode for the duration of the
            call and its previous mode is restored afterwards.
        memory: Ignored. A fresh ``Memory`` is built internally. The argument
            is kept for backward compatibility with existing callers.
        graph_data: Mapping with keys ``"edge_index"``, ``"edge_attr"``, ``"y"``,
            ``"x"``, ``"train_idx"`` and ``"test_idx"``. Column 1 of
            ``"edge_attr"`` is read as the raw edge timestamp.
        norm_stats: Training-time statistics ``"ea_mean"``, ``"ea_std"``,
            ``"t_min"`` and ``"t_max"``, used to normalize edge features and
            timestamps exactly as during training.
        time_encoder: Time-encoding module applied to normalized timestamps.
            Defaults to a newly constructed ``TimeEncoding(16)`` (the previous
            behavior). NOTE(review): if ``TimeEncoding`` has learnable
            parameters, pass the trained encoder instead. It must live on CPU.
        memory_dim: Dimension of the rebuilt node memory.
        batch_size: Number of training edges replayed per memory update.

    Returns:
        Tuple ``(roc_auc, average_precision, probs, y_true)``. ``probs`` and
        ``y_true`` are NumPy arrays aligned with ``test_idx``.
    """
    device = torch.device("cpu")

    edge_index = torch.tensor(graph_data["edge_index"], dtype=torch.long)
    raw_edge_attr = torch.tensor(graph_data["edge_attr"], dtype=torch.float32)
    labels = torch.tensor(graph_data["y"], dtype=torch.float32)

    # Node features are standardized with the statistics of this graph's own
    # x; no training statistics are read for x.
    x = torch.tensor(graph_data["x"], dtype=torch.float32).to(device)
    x = (x - x.mean(dim=0)) / (x.std(dim=0) + 1e-6)

    # Apply SAME normalization as training to edge features and timestamps.
    edge_attr = (raw_edge_attr - norm_stats["ea_mean"]) / norm_stats["ea_std"]
    timestamps = raw_edge_attr[:, 1]  # raw timestamp lives in column 1
    timestamps = (timestamps - norm_stats["t_min"]) / (
        norm_stats["t_max"] - norm_stats["t_min"] + 1e-6
    )

    test_idx = graph_data["test_idx"]
    train_idx = graph_data["train_idx"]

    # Rebuild memory from train edges only; the ``memory`` argument is
    # deliberately replaced, presumably so test edges cannot leak into it.
    memory = Memory(x.shape[0], memory_dim=memory_dim, device=device)
    if time_encoder is None:
        time_encoder = TimeEncoding(16).to(device)

    with _eval_mode(model, time_encoder), torch.no_grad():
        _replay_edges(model, memory, time_encoder, edge_index, edge_attr,
                      timestamps, train_idx, batch_size)

        # Evaluate on the test set using the train-only memory.
        u = edge_index[0, test_idx].to(device)
        v = edge_index[1, test_idx].to(device)
        h_u = memory.get(u)
        h_v = memory.get(v)
        edge_feat = edge_attr[test_idx].to(device)
        time_enc = time_encoder(timestamps[test_idx].to(device))
        logits = model.predict(h_u, h_v, edge_feat, x[u], x[v], time_enc)

    probs = torch.sigmoid(logits).cpu().numpy()
    y_true = labels[test_idx].cpu().numpy()

    roc = roc_auc_score(y_true, probs)
    pr = average_precision_score(y_true, probs)
    return roc, pr, probs, y_true