import torch
import torch.nn as nn
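
# -------------------------
# TIME ENCODING (illustrative sketch, not part of the original model)
# -------------------------
# TGN.compute_message and TGN.predict take a precomputed ``time_enc`` with
# 2 * time_dim features, but this file does not define how it is produced.
# A minimal sketch under that assumption: a learnable cosine encoding
# cos(w * dt + b) of elapsed time (a common choice in temporal graph models),
# computed for each endpoint and concatenated. What "elapsed time" means
# (e.g. time since a node's last memory update) is also an assumption.
# ``TimeEncoder`` and ``encode_pair`` are hypothetical names.
import torch
import torch.nn as nn


class TimeEncoder(nn.Module):
    """Map time deltas of shape [B] to features of shape [B, time_dim]."""

    def __init__(self, time_dim):
        super().__init__()
        # Learnable frequencies (weight) and phases (bias) of the cosines
        self.lin = nn.Linear(1, time_dim)

    def forward(self, dt):
        # dt: float tensor [B] -> cos(w * dt + b): [B, time_dim]
        return torch.cos(self.lin(dt.unsqueeze(-1)))

    def encode_pair(self, dt_u, dt_v):
        # Encode each endpoint's elapsed time and concatenate: [B, 2 * time_dim]
        return torch.cat([self(dt_u), self(dt_v)], dim=1)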


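class TGN(nn.Module):
    """Temporal Graph Network (TGN)-style model with per-node memory.

    Components:
      * message function: MLP over (both endpoint memories, edge features,
        time encoding), producing a message of size ``memory_dim``
      * memory updater: GRU cell that folds a message into a node's memory
      * edge decoder: time-aware link predictor over memory + static features
      * node classifier: per-node risk score from memory + static features

    Every ``time_enc`` argument must have ``2 * time_dim`` features; the input
    sizes of ``message_mlp`` and ``decoder`` are built around that width.
    """

    def __init__(self, memory_dim, node_dim, edge_dim, time_dim, hidden_dim=128):
        super().__init__()

        self.memory_dim = memory_dim
        self.node_dim = node_dim
        self.edge_dim = edge_dim
        self.time_dim = time_dim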

        # -------------------------
        # MESSAGE FUNCTION
        # -------------------------
        # Input: [h_u | h_v | edge_attr | time_enc] -> message of size memory_dim
        self.message_mlp = nn.Sequential(
            nn.Linear(2 * memory_dim + edge_dim + 2 * time_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, memory_dim),
        )

        # -------------------------
        # MEMORY UPDATE
        # -------------------------
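        # Despite the "mlp" in its name, this is a GRU cell: input = message,
        # hidden state = the node's current memory. The attribute name is kept
        # so existing checkpoints (state_dict keys) still load.
        self.update_mlp = nn.GRUCell(memory_dim, memory_dim)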

        # -------------------------
        # EDGE PREDICTOR (TIME-AWARE)
        # -------------------------
        # Input: [h_u | x_u | h_v | x_v | edge_attr | time_enc] -> one logit
        self.decoder = nn.Sequential(
            nn.Linear(
                2 * (memory_dim + node_dim) + edge_dim + 2 * time_dim,
                hidden_dim
            ),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Linear(hidden_dim // 2, 1),
        )

        # -------------------------
        # NODE RISK CLASSIFIER
        # -------------------------
        # Input: [memory | x] -> one logit; the hidden width is fixed at 64.
        self.node_classifier = nn.Sequential(
            nn.Linear(memory_dim + node_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 1),
        )

    # -------------------------
    # MESSAGE COMPUTATION
    # -------------------------
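    def compute_message(self, h_u, h_v, edge_attr, time_enc):
        """Return messages of shape [B, memory_dim] for a batch of B events.

        Expected shapes: h_u, h_v [B, memory_dim]; edge_attr [B, edge_dim];
        time_enc [B, 2 * time_dim].
        """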
        return self.message_mlp(
            torch.cat([h_u, h_v, edge_attr, time_enc], dim=1)
        )

    # -------------------------
    # MEMORY UPDATE
    # -------------------------
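    def update_memory(self, memory, node_ids, messages):
        """Fold one message per node into its memory with the GRU cell.

        ``memory`` ([num_nodes, memory_dim]) is modified in place and also
        returned. ``node_ids`` should be unique: PyTorch leaves the result of
        an indexed assignment with duplicate indices undefined, so aggregate
        messages per node before calling this.

        The new state is stored detached, so no gradient flows from later
        uses of ``memory`` back into ``update_mlp`` or ``message_mlp``
        through this call.
        """
        updated = self.update_mlp(messages, memory[node_ids])
        # Write back without autograd history (truncates backprop at the memory)
        memory[node_ids] = updated.detach()
        return memory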

    # -------------------------
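    # EDGE PREDICTION
    # -------------------------
    def predict(self, h_u, h_v, edge_attr, x_u, x_v, time_enc):
        """Return one raw logit per edge, shape [B] (no sigmoid applied).

        h_u, h_v: [B, memory_dim]; x_u, x_v: [B, node_dim];
        edge_attr: [B, edge_dim]; time_enc: [B, 2 * time_dim].
        Pair with ``nn.BCEWithLogitsLoss``, or apply ``torch.sigmoid`` to get
        probabilities.
        """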
        return self.decoder(
            torch.cat([h_u, x_u, h_v, x_v, edge_attr, time_enc], dim=1)
        ).squeeze(-1)

    # -------------------------
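    # NODE PREDICTION
    # -------------------------
    def predict_node(self, memory, x):
        """Return one raw risk logit per node, shape [N] (no sigmoid applied).

        memory: [N, memory_dim] memory rows of the nodes to score;
        x: [N, node_dim] static node features.
        """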
        combined = torch.cat([memory, x], dim=1)
        return self.node_classifier(combined).squeeze(-1)
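

# -------------------------
# TRAINING HELPERS (illustrative sketch, not part of the original model)
# -------------------------
# Two hypothetical helpers, shown under stated assumptions:
#   * aggregate_last: TGN.update_memory needs unique node_ids, so when a node
#     appears in several events of one batch, keep only its most recent
#     message. Assumes the events in the batch are sorted by time; a mean over
#     each node's messages would be another reasonable aggregator.
#   * link_prediction_loss: TGN.predict returns raw logits, so binary
#     cross-entropy with logits fits naturally. Assumes negative edges come
#     from a negative-sampling step that is not shown here.
import torch
import torch.nn.functional as F


def aggregate_last(node_ids, messages):
    """Reduce (node_ids [B], messages [B, D]) to one message per unique node.

    Returns (unique_ids [U], last_messages [U, D]); unique_ids are sorted.
    """
    # Batch position of every event, used to find each node's latest event
    pos = torch.arange(node_ids.numel(), device=node_ids.device)
    unique_ids, inverse = torch.unique(node_ids, return_inverse=True)
    # For each unique node, the largest batch position that maps to it
    last_pos = torch.full(
        (unique_ids.numel(),), -1, dtype=torch.long, device=node_ids.device
    )
    last_pos = last_pos.scatter_reduce(0, inverse, pos, reduce="amax")
    return unique_ids, messages[last_pos]


def link_prediction_loss(pos_logits, neg_logits):
    """BCE with logits: observed edges labelled 1, sampled negatives 0."""
    logits = torch.cat([pos_logits, neg_logits])
    labels = torch.cat([torch.ones_like(pos_logits), torch.zeros_like(neg_logits)])
    return F.binary_cross_entropy_with_logits(logits, labels)


# Hypothetical usage inside a training step:
#   ids, msgs = aggregate_last(src_ids, messages)
#   memory = model.update_memory(memory, ids, msgs)
#   loss = link_prediction_loss(model.predict(...), model.predict(...))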