import torch
import torch.nn as nn


class TGN(nn.Module):
    """Temporal graph network: message MLP, GRU memory updater and two heads.

    The module does not own the node memory; callers pass a ``memory`` tensor
    into :meth:`update_memory` and into the prediction methods.

    Args:
        memory_dim: Width of each node's memory vector.
        node_dim: Width of the node feature vectors (``x_u``, ``x_v``, ``x``).
        edge_dim: Width of the edge attribute vector.
        time_dim: Width of one time encoding. The message MLP and the edge
            decoder both reserve ``2 * time_dim`` input columns, so the
            ``time_enc`` passed to :meth:`compute_message` and :meth:`predict`
            must have that width (presumably two concatenated encodings —
            TODO confirm with the caller).
        hidden_dim: Hidden width of the message MLP and of the edge decoder
            (whose second hidden layer is ``hidden_dim // 2``).
        node_hidden_dim: Hidden width of the node classifier. Defaults to 64,
            the value that was previously hard-coded.
    """

    def __init__(
        self,
        memory_dim: int,
        node_dim: int,
        edge_dim: int,
        time_dim: int,
        hidden_dim: int = 128,
        node_hidden_dim: int = 64,
    ) -> None:
        super().__init__()
        # Keep every constructor dimension available for callers and
        # introspection (edge_dim / hidden_dim were previously dropped).
        self.memory_dim = memory_dim
        self.node_dim = node_dim
        self.edge_dim = edge_dim
        self.time_dim = time_dim
        self.hidden_dim = hidden_dim
        self.node_hidden_dim = node_hidden_dim

        # Message function: [h_u, h_v, edge_attr, time_enc] -> memory_dim.
        self.message_mlp = nn.Sequential(
            nn.Linear(2 * memory_dim + edge_dim + 2 * time_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, memory_dim),
        )

        # Memory updater: GRU cell whose input is the message and whose
        # hidden state is the node's current memory row.
        self.update_mlp = nn.GRUCell(memory_dim, memory_dim)

        # Time-aware edge decoder:
        # [h_u, x_u, h_v, x_v, edge_attr, time_enc] -> one raw score per edge.
        self.decoder = nn.Sequential(
            nn.Linear(
                2 * (memory_dim + node_dim) + edge_dim + 2 * time_dim,
                hidden_dim,
            ),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Linear(hidden_dim // 2, 1),
        )

        # Node risk classifier: [memory, x] -> one raw score per node.
        self.node_classifier = nn.Sequential(
            nn.Linear(memory_dim + node_dim, node_hidden_dim),
            nn.ReLU(),
            nn.Linear(node_hidden_dim, 1),
        )

    def compute_message(
        self,
        h_u: torch.Tensor,
        h_v: torch.Tensor,
        edge_attr: torch.Tensor,
        time_enc: torch.Tensor,
    ) -> torch.Tensor:
        """Return the ``memory_dim``-wide message for a batch of events.

        The inputs are concatenated along dim 1, so they must share the
        batch dimension. Their widths must sum to
        ``2 * memory_dim + edge_dim + 2 * time_dim``.
        """
        return self.message_mlp(
            torch.cat([h_u, h_v, edge_attr, time_enc], dim=1)
        )

    def update_memory(
        self,
        memory: torch.Tensor,
        node_ids: torch.Tensor,
        messages: torch.Tensor,
    ) -> torch.Tensor:
        """Update the rows ``memory[node_ids]`` with the GRU cell, in place.

        Args:
            memory: Node memory matrix. It is modified in place.
            node_ids: Indices of the rows to update, one per message.
            messages: Messages aligned with ``node_ids`` (one row each).

        Returns:
            The same ``memory`` tensor object, with the selected rows
            overwritten.

        The written rows are detached. Reading ``memory`` later therefore
        does not backpropagate into ``update_mlp`` or into ``messages``.

        NOTE(review): with duplicate entries in ``node_ids``, each duplicate
        is computed from the same old row. Which write survives the indexed
        assignment is undefined in PyTorch, so aggregate duplicates upstream.
        """
        updated = self.update_mlp(messages, memory[node_ids])
        memory[node_ids] = updated.detach()
        return memory

    def predict(
        self,
        h_u: torch.Tensor,
        h_v: torch.Tensor,
        edge_attr: torch.Tensor,
        x_u: torch.Tensor,
        x_v: torch.Tensor,
        time_enc: torch.Tensor,
    ) -> torch.Tensor:
        """Score a batch of candidate edges.

        The decoder sees ``[h_u, x_u, h_v, x_v, edge_attr, time_enc]``. This
        order differs from the positional argument order. The trailing
        size-1 dimension is squeezed, so one raw score is returned per row.
        """
        return self.decoder(
            torch.cat([h_u, x_u, h_v, x_v, edge_attr, time_enc], dim=1)
        ).squeeze(-1)

    def predict_node(self, memory: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        """Return one raw node-classifier score per row of ``[memory, x]``."""
        combined = torch.cat([memory, x], dim=1)
        return self.node_classifier(combined).squeeze(-1)