File size: 2,813 Bytes
a3682cf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import torch

from sklearn.metrics import roc_auc_score, average_precision_score
from src.tgn.time_encoding import TimeEncoding
from src.tgn.memory import Memory


def _replay_memory(model, memory, time_encoder, edge_index, edge_attr,
                   timestamps, edge_ids, batch_size):
    """Feed the given edges through ``model.compute_message`` to update ``memory``.

    Edges are processed in the order of ``edge_ids``, ``batch_size`` at a time.
    Within a batch, each edge's message is sent to both of its endpoints, and
    every touched node is updated with the mean of the messages it received.
    Intended to be called under ``torch.no_grad()``.
    """
    for start in range(0, len(edge_ids), batch_size):
        batch_ids = edge_ids[start:start + batch_size]

        u_i = edge_index[0, batch_ids]
        v_i = edge_index[1, batch_ids]

        time_enc_i = time_encoder(timestamps[batch_ids])
        msg = model.compute_message(
            memory.get(u_i), memory.get(v_i),
            edge_attr[batch_ids], time_enc_i,
        )

        # Each edge's message is delivered to both endpoints.
        node_ids = torch.cat([u_i, v_i])
        messages = torch.cat([msg, msg])

        # Mean-aggregate messages per distinct node in this batch.
        unique_nodes, inverse_idx = torch.unique(node_ids, return_inverse=True)
        agg_msg = torch.zeros_like(memory.memory[unique_nodes])
        agg_msg.index_add_(0, inverse_idx, messages)
        counts = torch.bincount(inverse_idx).unsqueeze(1)

        memory.update(unique_nodes, agg_msg / counts)


def evaluate(model, memory, graph_data, norm_stats, *, batch_size=1024,
             memory_dim=64, time_encoder=None):
    """Score the test edges of ``graph_data`` and return ROC-AUC / average precision.

    A fresh node memory is rebuilt by replaying the training edges through the
    model, then ``model.predict`` is applied to every test edge.

    Args:
        model: Object exposing ``compute_message`` and ``predict`` plus the
            ``training`` / ``eval()`` / ``train(mode)`` API of ``torch.nn.Module``.
            It is put in eval mode for the duration of the call and its previous
            mode is restored afterwards.
        memory: Ignored. A new ``Memory`` is always built from train edges only.
            The argument is kept for call-site compatibility.
        graph_data: Mapping with keys ``"edge_index"``, ``"edge_attr"``, ``"y"``,
            ``"x"``, ``"train_idx"`` and ``"test_idx"``. Column 1 of
            ``"edge_attr"`` is read as the edge timestamp.
        norm_stats: Mapping with ``"ea_mean"``, ``"ea_std"`` (edge-attribute
            standardisation) and ``"t_min"``, ``"t_max"`` (timestamp min-max
            scaling), as computed for training.
        batch_size: Number of train edges replayed into memory per step.
        memory_dim: Dimension of the rebuilt node memory.
        time_encoder: Optional trained time encoder. When ``None``, a freshly
            constructed ``TimeEncoding(16)`` is used (original behaviour).

    Returns:
        ``(roc_auc, average_precision, probs, y_true)``: two floats, then the
        predicted probabilities and true labels of the test edges as NumPy arrays.

    Raises:
        ValueError: from scikit-learn when the test labels contain a single class.
    """
    device = torch.device("cpu")

    edge_index = torch.tensor(graph_data["edge_index"], dtype=torch.long)
    raw_edge_attr = torch.tensor(graph_data["edge_attr"], dtype=torch.float32)
    labels = torch.tensor(graph_data["y"], dtype=torch.float32)

    x = torch.tensor(graph_data["x"], dtype=torch.float32).to(device)
    # NOTE(review): node features are standardised with their own statistics
    # over all nodes rather than values from norm_stats — confirm this matches
    # what training did.
    x = (x - x.mean(dim=0)) / (x.std(dim=0) + 1e-6)

    # Apply SAME normalization as training
    edge_attr = (raw_edge_attr - norm_stats["ea_mean"]) / norm_stats["ea_std"]

    # Timestamps come from the *raw* (un-standardised) column 1, min-max scaled
    # with the training range.
    timestamps = raw_edge_attr[:, 1]
    timestamps = (timestamps - norm_stats["t_min"]) / (
        norm_stats["t_max"] - norm_stats["t_min"] + 1e-6
    )

    test_idx = graph_data["test_idx"]
    train_idx = graph_data["train_idx"]

    # Rebuild memory from train edges only (the ``memory`` argument is unused).
    memory = Memory(x.shape[0], memory_dim=memory_dim, device=device)
    if time_encoder is None:
        # NOTE(review): a freshly constructed encoder will not match the one used
        # in training if TimeEncoding has learnable parameters — pass the trained
        # encoder via ``time_encoder`` in that case.
        time_encoder = TimeEncoding(16).to(device)
    else:
        time_encoder = time_encoder.to(device)

    # Inference must not run with dropout/batch-norm in training mode; restore
    # the callers' modes afterwards so evaluation has no lasting side effect.
    modules = (model, time_encoder)
    prev_modes = [m.training for m in modules]
    for m in modules:
        m.eval()

    try:
        with torch.no_grad():
            # NOTE(review): edges are replayed in train_idx order; presumably
            # train_idx is already chronological — confirm.
            _replay_memory(model, memory, time_encoder, edge_index, edge_attr,
                           timestamps, train_idx, batch_size)

            # Evaluate on test set
            u = edge_index[0, test_idx].to(device)
            v = edge_index[1, test_idx].to(device)

            logits = model.predict(
                memory.get(u), memory.get(v),
                edge_attr[test_idx].to(device),
                x[u], x[v],
                time_encoder(timestamps[test_idx].to(device)),
            )
            probs = torch.sigmoid(logits).cpu().numpy()
    finally:
        for m, was_training in zip(modules, prev_modes):
            m.train(was_training)

    y_true = labels[test_idx].cpu().numpy()

    roc = roc_auc_score(y_true, probs)
    pr = average_precision_score(y_true, probs)

    return roc, pr, probs, y_true