import torch
import numpy as np
from tqdm import tqdm

from src.tgn.memory import Memory
from src.tgn.model import TGN
from src.tgn.time_encoding import TimeEncoding

# Multiplier applied to [0, 1]-normalized timestamps before time encoding.
# NOTE(review): amplifies time differences to "force causality" per the
# original author's inline comment — presumably tuned empirically; confirm.
_TIME_SCALE = 5.0
# Dimensions shared between the TGN model and its node memory.
_MEMORY_DIM = 64
_TIME_DIM = 16


def _to_bool_mask(mask):
    """Coerce a boolean mask into a 1-D bool torch tensor.

    Accepts a ``torch.Tensor``, a ``numpy.ndarray``, or a pandas-like object
    exposing ``.values`` (e.g. a Series).  The original code crashed on
    tensors because it unconditionally fell through to ``.values``.
    """
    if isinstance(mask, torch.Tensor):
        return mask.to(dtype=torch.bool)
    if isinstance(mask, np.ndarray):
        return torch.tensor(mask, dtype=torch.bool)
    # Fall back to pandas-style objects exposing `.values`.
    return torch.tensor(mask.values, dtype=torch.bool)


def _push_messages(memory, model, u, v, edge_feat, time_enc):
    """Compute one message per edge and mean-aggregate it into the memory
    slots of both endpoints ``u`` and ``v``.

    Memory reads are detached so the memory acts as a stateful buffer:
    gradients never flow across batch boundaries.  (Under ``torch.no_grad()``
    the detach is a no-op, so this helper is also safe for inference.)
    """
    h_u = memory.get(u).detach()
    h_v = memory.get(v).detach()
    msg = model.compute_message(h_u, h_v, edge_feat, time_enc)

    # Each edge delivers the same message to both of its endpoints.
    node_ids = torch.cat([u, v])
    messages = torch.cat([msg, msg])

    unique_nodes, inverse_idx = torch.unique(node_ids, return_inverse=True)
    agg_msg = torch.zeros_like(memory.memory[unique_nodes])
    agg_msg.index_add_(0, inverse_idx, messages)
    # Every unique node occurs at least once in inverse_idx, so counts >= 1
    # and the mean below never divides by zero.
    counts = torch.bincount(inverse_idx).unsqueeze(1)
    memory.update(unique_nodes, agg_msg / counts)


def train_tgn(graph_data, batch_size=1024, num_epochs=3):
    """Train a TGN edge classifier on ``graph_data`` (CPU only).

    Args:
        graph_data: mapping with keys ``"edge_index"``, ``"edge_attr"``,
            ``"y"``, ``"x"``, ``"timestamps"``, ``"train_mask"`` and
            optionally ``"inference_mask"``.  Masks may be numpy arrays,
            torch tensors, or pandas-like objects with ``.values``.
        batch_size: number of edges processed per memory-update step.
        num_epochs: full passes over the training edges; node memory is
            reset at the start of each epoch.

    Returns:
        ``(model, memory, time_encoder, norm_stats)`` where ``norm_stats``
        holds the edge-feature standardization stats, the raw timestamp
        min/max, and the standardized node features — everything a caller
        needs to reproduce the input transforms at inference time.
    """
    device = torch.device("cpu")

    edge_index = torch.tensor(graph_data["edge_index"], dtype=torch.long)
    edge_attr = torch.tensor(graph_data["edge_attr"], dtype=torch.float32)
    labels = torch.tensor(graph_data["y"], dtype=torch.float32)

    # Standardize node features (eps guards constant columns).
    x = torch.tensor(graph_data["x"], dtype=torch.float32).to(device)
    x = (x - x.mean(dim=0)) / (x.std(dim=0) + 1e-6)

    # Standardize ALL edge features; stats are returned via norm_stats so
    # callers can apply the identical transform to unseen edges.
    ea_mean = edge_attr.mean(dim=0)
    ea_std = edge_attr.std(dim=0) + 1e-6
    edge_attr = (edge_attr - ea_mean) / ea_std

    # Min-max normalize raw timestamps to [0, 1] for the time encoder
    # (eps guards the degenerate single-distinct-timestamp case).
    timestamps_raw = torch.tensor(graph_data["timestamps"], dtype=torch.float32)
    t_min = timestamps_raw.min()
    t_max = timestamps_raw.max()
    timestamps = (timestamps_raw - t_min) / (t_max - t_min + 1e-6)

    num_nodes = x.shape[0]
    model = TGN(
        memory_dim=_MEMORY_DIM,
        node_dim=x.shape[1],
        edge_dim=edge_attr.shape[1],
        time_dim=_TIME_DIM,
    ).to(device)
    time_encoder = TimeEncoding(_TIME_DIM).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    # Class-imbalance weighting, capped so extreme imbalance (including an
    # inf ratio from a zero positive count) cannot blow up the loss.
    raw_pw = (labels == 0).sum().float() / (labels == 1).sum().float()
    pos_weight = torch.clamp(raw_pw, max=10.0)
    loss_fn = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)

    # Train on the train split ONLY.
    train_idx = torch.where(_to_bool_mask(graph_data["train_mask"]))[0]
    N = len(train_idx)

    # Bound before the loop so `memory` exists even when num_epochs == 0
    # (the original raised NameError at the return in that case).
    memory = Memory(num_nodes, memory_dim=_MEMORY_DIM, device=device)

    for epoch in range(num_epochs):
        total_loss = 0.0
        # Fresh memory each epoch: the edge stream replays from scratch.
        memory = Memory(num_nodes, memory_dim=_MEMORY_DIM, device=device)
        for i in tqdm(range(0, N, batch_size)):
            batch_ids = train_idx[i:i + batch_size]
            u = edge_index[0, batch_ids].to(device)
            v = edge_index[1, batch_ids].to(device)
            edge_feat = edge_attr[batch_ids].to(device)
            # Amplify time differences to force causality.
            t = timestamps[batch_ids].to(device) * _TIME_SCALE
            labels_batch = labels[batch_ids].to(device)

            # Predict from PRE-update memory, then fold the batch's edges
            # into memory afterwards (standard TGN ordering).
            h_u = memory.get(u)
            h_v = memory.get(v)
            time_enc = time_encoder(t)
            logits = model.predict(h_u, h_v, edge_feat, x[u], x[v], time_enc)
            logits = torch.clamp(logits, -10, 10)  # stabilize BCE
            loss = loss_fn(logits, labels_batch)

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            total_loss += loss.item()

            _push_messages(memory, model, u, v, edge_feat, time_enc)

        print(f"Epoch {epoch} Loss: {total_loss:.4f}")

    # Optionally warm the memory on a held-out edge stream so the returned
    # memory reflects the zero-shot test distribution.
    if "inference_mask" in graph_data:
        inf_idx = torch.where(_to_bool_mask(graph_data["inference_mask"]))[0]
        model.eval()
        with torch.no_grad():
            for i in range(0, len(inf_idx), batch_size):
                batch_ids = inf_idx[i:i + batch_size]
                u = edge_index[0, batch_ids].to(device)
                v = edge_index[1, batch_ids].to(device)
                edge_feat = edge_attr[batch_ids].to(device)
                t = timestamps[batch_ids].to(device) * _TIME_SCALE
                time_enc = time_encoder(t)
                _push_messages(memory, model, u, v, edge_feat, time_enc)

    norm_stats = {
        "ea_mean": ea_mean,
        "ea_std": ea_std,
        "t_min": t_min,
        "t_max": t_max,
        "x": x,
    }
    return model, memory, time_encoder, norm_stats