| import torch |
| import numpy as np |
| from tqdm import tqdm |
|
|
| from src.tgn.memory import Memory |
| from src.tgn.model import TGN |
| from src.tgn.time_encoding import TimeEncoding |
|
|
|
|
def _as_bool_mask(mask):
    """Coerce a mask (numpy array or pandas-like Series) into a bool tensor."""
    if isinstance(mask, np.ndarray):
        return torch.tensor(mask, dtype=torch.bool)
    # Assumed pandas Series (has .values) — same assumption as the original code.
    return torch.tensor(mask.values, dtype=torch.bool)


def _apply_messages(memory, u, v, msg):
    """Mean-aggregate the batch messages per node and write them into memory.

    Each edge contributes the same message to both of its endpoints; a node
    touched by several edges in the batch receives the mean of its messages.
    """
    node_ids = torch.cat([u, v])
    messages = torch.cat([msg, msg])
    unique_nodes, inverse_idx = torch.unique(node_ids, return_inverse=True)
    agg = torch.zeros_like(memory.memory[unique_nodes])
    agg.index_add_(0, inverse_idx, messages)
    # Every unique node appears at least once, so counts >= 1 (no div-by-zero).
    counts = torch.bincount(inverse_idx).unsqueeze(1)
    memory.update(unique_nodes, agg / counts)


def train_tgn(graph_data, batch_size=1024, num_epochs=3,
              memory_dim=64, time_dim=16, lr=1e-3, time_scale=5.0):
    """Train a TGN edge classifier on CPU and roll node memory forward.

    Parameters
    ----------
    graph_data : dict with keys "edge_index", "edge_attr", "y", "x",
        "timestamps", "train_mask" and optionally "inference_mask".
        Masks may be numpy arrays or pandas Series.
    batch_size : number of edges per training batch.
    num_epochs : passes over the training edges; node memory is
        re-initialised at the start of each epoch.
    memory_dim, time_dim : sizes of the TGN memory vectors and time encoding.
    lr : Adam learning rate.
    time_scale : multiplier applied to the [0, 1]-normalised timestamps
        before time encoding.

    Returns
    -------
    (model, memory, time_encoder, norm_stats) — norm_stats contains the
    edge-attribute and timestamp normalisation constants plus the
    standardised node features, for reuse at inference time.
    """
    device = torch.device("cpu")

    edge_index = torch.tensor(graph_data["edge_index"], dtype=torch.long)
    edge_attr = torch.tensor(graph_data["edge_attr"], dtype=torch.float32)
    labels = torch.tensor(graph_data["y"], dtype=torch.float32)

    # Standardise node features (eps guards constant columns).
    x = torch.tensor(graph_data["x"], dtype=torch.float32).to(device)
    x = (x - x.mean(dim=0)) / (x.std(dim=0) + 1e-6)

    # Standardise edge features; stats are returned for inference-time reuse.
    ea_mean = edge_attr.mean(dim=0)
    ea_std = edge_attr.std(dim=0) + 1e-6
    edge_attr = (edge_attr - ea_mean) / ea_std

    # Min-max normalise timestamps into [0, 1].
    timestamps_raw = torch.tensor(graph_data["timestamps"], dtype=torch.float32)
    t_min = timestamps_raw.min()
    t_max = timestamps_raw.max()
    timestamps = (timestamps_raw - t_min) / (t_max - t_min + 1e-6)

    num_nodes = x.shape[0]

    model = TGN(
        memory_dim=memory_dim,
        node_dim=x.shape[1],
        edge_dim=edge_attr.shape[1],
        time_dim=time_dim,
    ).to(device)
    time_encoder = TimeEncoding(time_dim).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    train_mask = _as_bool_mask(graph_data["train_mask"])
    train_idx = torch.where(train_mask)[0]
    N = len(train_idx)

    # Class-imbalance weight from TRAINING labels only. (Previously computed
    # over ALL labels, leaking inference-set label statistics into the loss.)
    # Clamped so a very rare positive class cannot blow up gradients; if there
    # are no positives the inf ratio is clamped to 10 as before.
    train_labels = labels[train_idx]
    raw_pw = (train_labels == 0).sum().float() / (train_labels == 1).sum().float()
    pos_weight = torch.clamp(raw_pw, max=10.0)
    loss_fn = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)

    for epoch in range(num_epochs):
        total_loss = 0.0
        # Fresh memory each epoch so every epoch replays the same trajectory.
        memory = Memory(num_nodes, memory_dim=memory_dim, device=device)

        for i in tqdm(range(0, N, batch_size)):
            batch_ids = train_idx[i:i + batch_size]

            u = edge_index[0, batch_ids].to(device)
            v = edge_index[1, batch_ids].to(device)
            edge_feat = edge_attr[batch_ids].to(device)
            t = timestamps[batch_ids].to(device) * time_scale
            labels_batch = labels[batch_ids].to(device)

            h_u = memory.get(u)
            h_v = memory.get(v)
            time_enc = time_encoder(t)

            logits = model.predict(h_u, h_v, edge_feat, x[u], x[v], time_enc)
            # Clamp logits for numerical stability of BCEWithLogits.
            logits = torch.clamp(logits, -10, 10)

            loss = loss_fn(logits, labels_batch)
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            total_loss += loss.item()

            # Memory write happens AFTER the gradient step; detach so the
            # write does not extend the autograd graph across batches.
            msg = model.compute_message(
                memory.get(u).detach(),
                memory.get(v).detach(),
                edge_feat,
                time_enc,
            )
            _apply_messages(memory, u, v, msg)

        print(f"Epoch {epoch} Loss: {total_loss:.4f}")

    # Optionally roll memory forward over held-out edges (no parameter updates)
    # so the returned memory reflects the full event history.
    if "inference_mask" in graph_data:
        inf_idx = torch.where(_as_bool_mask(graph_data["inference_mask"]))[0]
        model.eval()
        with torch.no_grad():
            for i in range(0, len(inf_idx), batch_size):
                batch_ids = inf_idx[i:i + batch_size]
                u = edge_index[0, batch_ids].to(device)
                v = edge_index[1, batch_ids].to(device)
                edge_feat = edge_attr[batch_ids].to(device)
                t = timestamps[batch_ids].to(device) * time_scale

                time_enc = time_encoder(t)
                msg = model.compute_message(
                    memory.get(u), memory.get(v), edge_feat, time_enc
                )
                _apply_messages(memory, u, v, msg)

    norm_stats = {
        "ea_mean": ea_mean, "ea_std": ea_std,
        "t_min": t_min, "t_max": t_max,
        "x": x,
    }
    return model, memory, time_encoder, norm_stats