| """ |
| models/tgat.py |
| ============== |
| Temporal Graph Attention Network (TGAT) |
| Xu et al., "Inductive Representation Learning on Temporal Graphs" (ICLR 2020) |
| |
| Architecture |
| ------------ |
| - Sinusoidal time encoding (reuses src/tgn/time_encoding.py) |
| - Per-node ring buffer of K most recent temporal neighbors |
| - Multi-head scaled dot-product attention over temporal neighborhood |
| - GRU-cell aggregator updates node memory after each event |
| - Node classifier head: memory → fraud probability |
| |
| Event processing (streaming, chronological): |
| For each edge (u, v, t, edge_feat): |
| 1. Retrieve last K neighbors of u from buffer → {(t_i, h_i, e_i)} |
| 2. Build query: Q = W_q(cat(h_u, φ(0))) [current state at t] |
| Build keys: K = W_k(cat(h_i, φ(t−t_i))) [neighbor state at t_i] |
| Build vals: V = W_v(cat(h_i, e_i, φ(t−t_i))) [neighbor context] |
| 3. attn = softmax(Q K^T / √d), z = attn·V |
    4. h_u ← GRU(merge(z ⊕ h_u), h_u)           [update sender memory]
| 5. Symmetrically update h_v using u's neighborhood |
| 6. Append (t, h_u, h_v, e) to neighbor buffers |
| """ |
|
|
| from __future__ import annotations |
|
|
| from collections import defaultdict |
| from typing import List |
|
|
| import numpy as np |
| import pandas as pd |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
| from models.base import TemporalModel |
| from models.tgn_wrapper import _make_users_df |
| from src.graph.graph_builder import build_edge_features |
| from src.tgn.time_encoding import TimeEncoding |
|
|
|
|
| |
| |
| |
|
|
class _TGATModule(nn.Module):
    """Learnable TGAT components.

    Holds the sinusoidal time encoder, the multi-head attention projections
    over a node's temporal neighborhood, the GRU-based memory updater, and a
    small MLP head mapping a node memory to a fraud logit.
    """

    def __init__(
        self,
        memory_dim: int,
        edge_dim: int,
        time_dim: int,
        num_heads: int,
    ):
        """
        Args:
            memory_dim: Width of per-node memory vectors; also used as the
                total attention dimension.
            edge_dim: Width of edge feature vectors.
            time_dim: Size parameter for ``TimeEncoding``.
                NOTE(review): the linear input widths below assume the
                encoding emits ``2 * time_dim`` features (sin/cos pairs) —
                confirm against src/tgn/time_encoding.py.
            num_heads: Number of attention heads; must divide ``memory_dim``.

        Raises:
            ValueError: If ``memory_dim`` is not divisible by ``num_heads``.
        """
        super().__init__()
        self.memory_dim = memory_dim
        self.time_enc = TimeEncoding(time_dim)

        # Projection input widths (see module docstring):
        #   query = cat(h_u, phi(0)); key = cat(h_i, phi(dt));
        #   value = cat(h_i, e_i, phi(dt)).
        q_in = memory_dim + 2 * time_dim
        kv_base = memory_dim + 2 * time_dim
        v_in = memory_dim + edge_dim + 2 * time_dim

        self.attn_dim = memory_dim
        self.num_heads = num_heads
        # Raise (not assert) so the check survives `python -O`.
        if self.attn_dim % num_heads != 0:
            raise ValueError("attn_dim must be divisible by num_heads")

        self.W_q = nn.Linear(q_in, self.attn_dim, bias=False)
        self.W_k = nn.Linear(kv_base, self.attn_dim, bias=False)
        self.W_v = nn.Linear(v_in, self.attn_dim, bias=False)

        # 1/sqrt(d_head) scaling for scaled dot-product attention.
        self.scale = (self.attn_dim // num_heads) ** -0.5

        # Merge the attention message with current memory before the GRU step.
        self.merge = nn.Linear(self.attn_dim + memory_dim, memory_dim)
        self.gru = nn.GRUCell(memory_dim, memory_dim)

        # Node classifier head: memory -> fraud logit.
        self.classifier = nn.Sequential(
            nn.Linear(memory_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 1),
        )

    def attend(
        self,
        h_u: torch.Tensor,
        h_nbrs: torch.Tensor,
        e_nbrs: torch.Tensor,
        dt_nbrs: torch.Tensor,
        mask: torch.Tensor,
    ) -> torch.Tensor:
        """Compute multi-head attention over temporal neighborhood.

        Args:
            h_u: (B, memory_dim) current memory of the query nodes.
            h_nbrs: (B, K, memory_dim) stored neighbor memory snapshots.
            e_nbrs: (B, K, edge_dim) edge features of the past interactions.
            dt_nbrs: (B, K) non-negative time deltas to each neighbor event.
            mask: (B, K) bool, True where the slot holds a real neighbor.

        Returns:
            (B, attn_dim) aggregated neighborhood message; rows for nodes
            with no valid neighbors come out all-zero.
        """
        B, K = dt_nbrs.shape
        H = self.num_heads
        d_h = self.attn_dim // H

        # phi(0) for the query (the node "now"); phi(dt) per neighbor slot.
        phi_0 = self.time_enc(torch.zeros(B, device=h_u.device))
        phi_dt = self.time_enc(dt_nbrs.reshape(-1)).reshape(B, K, -1)

        q_in = torch.cat([h_u, phi_0], dim=-1)
        Q = self.W_q(q_in).view(B, H, d_h)

        h_nbrs_flat = h_nbrs.reshape(B * K, -1)
        phi_dt_flat = phi_dt.reshape(B * K, -1)
        k_in = torch.cat([h_nbrs_flat, phi_dt_flat], dim=-1)
        K_ = self.W_k(k_in).view(B, K, H, d_h)
        K_ = K_.permute(0, 2, 1, 3)  # (B, H, K, d_h)

        v_in = torch.cat([h_nbrs_flat, e_nbrs.reshape(B * K, -1), phi_dt_flat], dim=-1)
        V = self.W_v(v_in).view(B, K, H, d_h)
        V = V.permute(0, 2, 1, 3)  # (B, H, K, d_h)

        # (B, H, 1, d_h) @ (B, H, d_h, K) -> (B, H, K) raw scores.
        scores = (Q.unsqueeze(2) @ K_.transpose(-2, -1)).squeeze(2)
        scores = scores * self.scale

        if mask is not None:
            inv_mask = ~mask.unsqueeze(1)  # broadcast over heads
            scores = scores.masked_fill(inv_mask, float("-inf"))

        attn = F.softmax(scores, dim=-1)
        # A fully-masked row softmaxes to NaN; treat it as "no message".
        attn = torch.nan_to_num(attn, nan=0.0)

        z = (attn.unsqueeze(-1) * V).sum(dim=2)  # (B, H, d_h)
        z = z.reshape(B, self.attn_dim)

        return z

    def update(self, h_u: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
        """GRU memory update: fold the neighborhood message ``z`` into ``h_u``."""
        merged = self.merge(torch.cat([z, h_u], dim=-1))
        return self.gru(merged, h_u)

    def classify(self, memory: torch.Tensor) -> torch.Tensor:
        """Map node memories (B, memory_dim) to fraud logits (B,)."""
        return self.classifier(memory).squeeze(-1)
|
|
|
|
| |
| |
| |
|
|
class _TGATStreamer:
    """
    Maintains per-node memory and temporal neighbor buffers.
    Processes events in a batched manner (approximate — same-batch
    events use pre-batch memory state, standard practice for scalability).
    """

    def __init__(
        self,
        module: _TGATModule,
        n_nodes: int,
        memory_dim: int,
        edge_dim: int,
        n_neighbors: int,
        device: torch.device,
    ):
        """
        Args:
            module: ``_TGATModule`` providing attend/update.
            n_nodes: Number of addressable node ids (rows of ``memory``).
            memory_dim: Width of each memory row.
            edge_dim: Width of stored edge feature vectors.
            n_neighbors: Ring-buffer capacity K per node.
            device: Device on which memory and batch tensors live.
        """
        self.module = module
        self.memory_dim = memory_dim
        self.edge_dim = edge_dim
        self.n_neighbors = n_neighbors
        self.device = device

        # Node memory, zero-initialised; rows are always written detached.
        self.memory = torch.zeros(n_nodes, memory_dim, device=device)

        # Per-node ring buffers: event time, the *other* endpoint's memory
        # snapshot (detached, kept on CPU), and the event's edge features.
        self.nbr_times: List[List[float]] = [[] for _ in range(n_nodes)]
        self.nbr_h: List[List[torch.Tensor]] = [[] for _ in range(n_nodes)]
        self.nbr_e: List[List[torch.Tensor]] = [[] for _ in range(n_nodes)]

    def _write_memory_rows(
        self,
        node_ids: torch.Tensor,
        values: torch.Tensor,
    ) -> None:
        """Deterministic last-write-wins update for repeated node ids in a batch."""
        for idx in range(len(node_ids)):
            self.memory[int(node_ids[idx].item())] = values[idx].detach()

    def _get_neighbor_tensors(
        self, node_ids: torch.Tensor
    ):
        """
        Returns padded (h_nbrs, e_nbrs, dt_nbrs, mask) for a batch of nodes.

        NOTE(review): returns all-zero tensors with an all-False mask and is
        never called in this file — it looks like an abandoned stub
        superseded by ``_fill_neighbor_batch``. Kept for compatibility.
        """
        B = len(node_ids)
        K = self.n_neighbors
        mem_dim = self.memory_dim
        e_dim = self.edge_dim
        device = self.device

        h_out = torch.zeros(B, K, mem_dim, device=device)
        e_out = torch.zeros(B, K, e_dim, device=device)
        dt_out = torch.zeros(B, K, device=device)
        mask = torch.zeros(B, K, dtype=torch.bool, device=device)

        return h_out, e_out, dt_out, mask

    def _fill_neighbor_batch(
        self,
        node_ids: torch.Tensor,
        current_times: torch.Tensor,
    ):
        """
        Fills neighbor tensors for a batch, using the stored per-node buffers.

        Returns padded (h_nbrs, e_nbrs, dt_nbrs, mask); slots beyond a
        node's buffer length stay zero with mask=False.
        """
        B = len(node_ids)
        K = self.n_neighbors
        mem_dim = self.memory_dim
        e_dim = self.edge_dim
        device = self.device

        h_out = torch.zeros(B, K, mem_dim, device=device)
        e_out = torch.zeros(B, K, e_dim, device=device)
        dt_out = torch.zeros(B, K, device=device)
        mask = torch.zeros(B, K, dtype=torch.bool, device=device)

        node_ids_np = node_ids.cpu().numpy()
        times_np = current_times.cpu().numpy()

        for b_idx, (nid, t_cur) in enumerate(zip(node_ids_np, times_np)):
            buf_t = self.nbr_times[nid]
            buf_h = self.nbr_h[nid]
            buf_e = self.nbr_e[nid]
            n_valid = len(buf_t)
            if n_valid == 0:
                continue
            n_use = min(n_valid, K)
            # Take the most recent n_use events, oldest first; clamp dt at 0
            # in case of same-timestamp events.
            for k, i in enumerate(range(n_valid - n_use, n_valid)):
                h_out[b_idx, k] = buf_h[i]
                e_out[b_idx, k] = buf_e[i]
                dt_out[b_idx, k] = max(0.0, float(t_cur) - float(buf_t[i]))
                mask[b_idx, k] = True

        return h_out, e_out, dt_out, mask

    def _update_buffers(
        self,
        node_ids_np: np.ndarray,
        times_np: np.ndarray,
        h_others: torch.Tensor,
        edge_feats: torch.Tensor,
    ):
        """Add events to per-node neighbor buffers (detached)."""
        for i, nid in enumerate(node_ids_np):
            self.nbr_times[nid].append(float(times_np[i]))
            self.nbr_h[nid].append(h_others[i].detach().cpu())
            self.nbr_e[nid].append(edge_feats[i].detach().cpu())
            # Enforce the ring-buffer capacity: drop the oldest entry.
            if len(self.nbr_times[nid]) > self.n_neighbors:
                self.nbr_times[nid].pop(0)
                self.nbr_h[nid].pop(0)
                self.nbr_e[nid].pop(0)

    def process_batch(
        self,
        u_ids: torch.Tensor,
        v_ids: torch.Tensor,
        times: torch.Tensor,
        edge_feats: torch.Tensor,
        compute_grad: bool = True,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Process a batch of events and update memory.

        Returns the updated memories ``(h_u_new, h_v_new)`` (not logits);
        during training these feed the edge-level classifier.
        ``compute_grad=False`` disables autograd for the whole step — it
        never re-enables gradients if the caller already turned them off.
        """
        # Fix: ``compute_grad`` previously had no effect.
        with torch.set_grad_enabled(compute_grad and torch.is_grad_enabled()):
            module = self.module

            # Pre-batch snapshot: events in the same batch do not see each
            # other's updates (batched approximation, see class docstring).
            h_u = self.memory[u_ids].clone()
            h_v = self.memory[v_ids].clone()

            u_np = u_ids.cpu().numpy()
            v_np = v_ids.cpu().numpy()
            t_np = times.cpu().numpy()

            # Sender side: attend over u's temporal neighborhood, then GRU.
            h_nbrs_u, e_nbrs_u, dt_u, mask_u = self._fill_neighbor_batch(u_ids, times)
            z_u = module.attend(h_u, h_nbrs_u, e_nbrs_u, dt_u, mask_u)
            h_u_new = module.update(h_u.detach(), z_u)

            # Receiver side, symmetric.
            h_nbrs_v, e_nbrs_v, dt_v, mask_v = self._fill_neighbor_batch(v_ids, times)
            z_v = module.attend(h_v, h_nbrs_v, e_nbrs_v, dt_v, mask_v)
            h_v_new = module.update(h_v.detach(), z_v)

            # Persist detached memory rows (last write wins within a batch).
            self._write_memory_rows(u_ids, h_u_new)
            self._write_memory_rows(v_ids, h_v_new)

            # Each endpoint remembers the *other* endpoint's new state.
            self._update_buffers(u_np, t_np, h_v_new, edge_feats)
            self._update_buffers(v_np, t_np, h_u_new, edge_feats)

            return h_u_new, h_v_new

    def reset(self):
        """Zero all memory rows and clear every neighbor buffer."""
        self.memory.zero_()
        self.nbr_times = [[] for _ in range(self.memory.shape[0])]
        self.nbr_h = [[] for _ in range(self.memory.shape[0])]
        self.nbr_e = [[] for _ in range(self.memory.shape[0])]
|
|
|
|
| |
| |
| |
|
|
class TGATWrapper(TemporalModel):
    """TGAT wrapped behind the unified TemporalModel interface.

    Lifecycle: ``fit`` trains the temporal module on the training stream,
    ``train_node_classifier`` fits the node-level head on labelled nodes,
    and ``predict`` streams evaluation events then scores requested nodes.
    """

    def __init__(
        self,
        memory_dim: int = 64,
        time_dim: int = 8,
        num_heads: int = 4,
        n_neighbors: int = 10,
        device: str = "cpu",
    ):
        """
        Args:
            memory_dim: Per-node memory width (also attention width).
            time_dim: Size parameter for the sinusoidal time encoding.
            num_heads: Attention heads; must divide ``memory_dim``.
            n_neighbors: Ring-buffer capacity K of most recent neighbors.
            device: Torch device string, e.g. "cpu" or "cuda".
        """
        self.memory_dim = memory_dim
        self.time_dim = time_dim
        self.num_heads = num_heads
        self.n_neighbors = n_neighbors
        self.device = torch.device(device)

        self._module: _TGATModule | None = None      # learnable TGAT core
        self._streamer: _TGATStreamer | None = None  # memory + buffers
        self._norm_stats: dict | None = None         # train-set normalisation
        self._n_nodes: int = 0                       # addressable node ids
        self._edge_dim: int = 0                      # edge feature width

    @property
    def name(self) -> str:
        return "TGAT"

    def _ensure_streamer_capacity(self, max_node_id: int) -> None:
        """Grow streamer memory/buffers so ids up to ``max_node_id`` index safely.

        Evaluation streams may contain node ids never seen during training;
        without growth, indexing ``memory`` or the neighbor buffers raises
        IndexError. New rows start as zero memory with empty buffers.
        """
        streamer = self._streamer
        current = streamer.memory.shape[0]
        if max_node_id < current:
            return
        extra = max_node_id + 1 - current
        streamer.memory = torch.cat(
            [streamer.memory,
             torch.zeros(extra, self.memory_dim, device=self.device)],
            dim=0,
        )
        streamer.nbr_times.extend([] for _ in range(extra))
        streamer.nbr_h.extend([] for _ in range(extra))
        streamer.nbr_e.extend([] for _ in range(extra))
        self._n_nodes = max_node_id + 1

    def fit(self, df_train: pd.DataFrame, num_epochs: int = 3) -> None:
        """Train TGAT on a chronological stream of transactions.

        Expects ``df_train`` to carry at least ``timestamp``, ``sender_id``,
        ``receiver_id`` and ``is_fraud`` columns. Trains the temporal module
        end-to-end using an auxiliary edge-level classifier (BCE on edge
        logits). The node-level head ``_node_clf`` is created here but only
        trained by ``train_node_classifier``.
        """
        df_train = df_train.sort_values("timestamp").reset_index(drop=True)

        # Edge features, z-normalised with train-set statistics.
        edge_feats_np = build_edge_features(df_train)
        edge_dim = edge_feats_np.shape[1]
        self._edge_dim = edge_dim

        ea_mean = edge_feats_np.mean(axis=0)
        ea_std = edge_feats_np.std(axis=0) + 1e-6
        edge_feats_np = (edge_feats_np - ea_mean) / ea_std

        # Timestamps min-max scaled to [0, 1]; stretched to [0, 5] below.
        t_vals = df_train["timestamp"].values.astype(np.float32)
        t_min, t_max = t_vals.min(), t_vals.max()
        t_norm = (t_vals - t_min) / (t_max - t_min + 1e-6)

        # Stats reused verbatim at prediction time.
        self._norm_stats = {
            "ea_mean": ea_mean, "ea_std": ea_std,
            "t_min": t_min, "t_max": t_max,
        }

        all_nodes = np.union1d(
            df_train["sender_id"].values, df_train["receiver_id"].values
        )
        n_nodes = int(all_nodes.max()) + 1
        self._n_nodes = n_nodes

        module = _TGATModule(
            memory_dim=self.memory_dim,
            edge_dim=edge_dim,
            time_dim=self.time_dim,
            num_heads=self.num_heads,
        ).to(self.device)
        self._module = module

        streamer = _TGATStreamer(
            module=module,
            n_nodes=n_nodes,
            memory_dim=self.memory_dim,
            edge_dim=edge_dim,
            n_neighbors=self.n_neighbors,
            device=self.device,
        )
        self._streamer = streamer

        y = torch.tensor(df_train["is_fraud"].values, dtype=torch.float32)
        u_ids = torch.tensor(df_train["sender_id"].values, dtype=torch.long)
        v_ids = torch.tensor(df_train["receiver_id"].values, dtype=torch.long)
        ef_all = torch.tensor(edge_feats_np, dtype=torch.float32)
        t_all = torch.tensor(t_norm * 5.0, dtype=torch.float32)

        # Class-imbalance weighting, capped at 10 to avoid exploding loss.
        raw_pw = (y == 0).sum() / ((y == 1).sum() + 1e-6)
        pos_weight = torch.clamp(raw_pw, max=10.0).to(self.device)
        loss_fn = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
        optimiser = torch.optim.Adam(module.parameters(), lr=1e-3)

        # Auxiliary edge-level classifier, used only as a training signal.
        edge_classifier = nn.Sequential(
            nn.Linear(self.memory_dim * 2 + edge_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 1),
        ).to(self.device)
        self._edge_clf = edge_classifier
        optimiser.add_param_group({"params": edge_classifier.parameters()})

        batch_size = 512
        N = len(df_train)

        for epoch in range(num_epochs):
            # Replay the stream from scratch each epoch.
            streamer.reset()
            total_loss = 0.0

            for i in range(0, N, batch_size):
                j = min(i + batch_size, N)
                u_b = u_ids[i:j].to(self.device)
                v_b = v_ids[i:j].to(self.device)
                t_b = t_all[i:j].to(self.device)
                ef_b = ef_all[i:j].to(self.device)
                y_b = y[i:j].to(self.device)

                h_u, h_v = streamer.process_batch(u_b, v_b, t_b, ef_b)

                edge_in = torch.cat([h_u, h_v, ef_b], dim=-1)
                logits = edge_classifier(edge_in).squeeze(-1)
                logits = torch.clamp(logits, -10, 10)  # keep BCE stable

                loss = loss_fn(logits, y_b)
                optimiser.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(module.parameters(), 1.0)
                optimiser.step()

                total_loss += loss.item()

            print(f"[TGAT] Epoch {epoch + 1}/{num_epochs} Loss: {total_loss:.4f}")

        # Node-level head; weights stay random until train_node_classifier().
        self._node_clf = nn.Sequential(
            nn.Linear(self.memory_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 1),
        ).to(self.device)

    def predict(self, df_eval: pd.DataFrame, eval_nodes: List[int]) -> np.ndarray:
        """Stream evaluation events, then score ``eval_nodes`` from memory.

        Returns float32 fraud probabilities aligned with ``eval_nodes``.
        If ``train_node_classifier`` was never called, the head is randomly
        initialised and the scores are uncalibrated.
        """
        assert self._streamer is not None, "Call fit() first."
        df_eval = df_eval.sort_values("timestamp").reset_index(drop=True)

        # Bug fix: evaluation data may reference node ids unseen in training;
        # grow memory/buffers up front so indexing cannot go out of range.
        if len(df_eval) > 0:
            max_id = int(
                max(df_eval["sender_id"].max(), df_eval["receiver_id"].max())
            )
            self._ensure_streamer_capacity(max_id)

        ns = self._norm_stats
        ef_np = build_edge_features(df_eval).astype(np.float32)
        ef_np = (ef_np - ns["ea_mean"]) / ns["ea_std"]

        t_vals = df_eval["timestamp"].values.astype(np.float32)
        t_norm = (t_vals - ns["t_min"]) / (ns["t_max"] - ns["t_min"] + 1e-6)

        u_ids = torch.tensor(df_eval["sender_id"].values, dtype=torch.long)
        v_ids = torch.tensor(df_eval["receiver_id"].values, dtype=torch.long)
        ef_t = torch.tensor(ef_np, dtype=torch.float32)
        t_t = torch.tensor(t_norm * 5.0, dtype=torch.float32)  # same scale as fit

        self._module.eval()
        with torch.no_grad():
            batch_size = 512
            for i in range(0, len(df_eval), batch_size):
                j = min(i + batch_size, len(df_eval))
                self._streamer.process_batch(
                    u_ids[i:j].to(self.device),
                    v_ids[i:j].to(self.device),
                    t_t[i:j].to(self.device),
                    ef_t[i:j].to(self.device),
                    compute_grad=False,
                )

        # Clamp requested ids defensively (capacity growth above makes this
        # a no-op for any id present in df_eval).
        eval_t = torch.tensor(
            [min(n, self._n_nodes - 1) for n in eval_nodes],
            dtype=torch.long, device=self.device,
        )
        node_emb = self._streamer.memory[eval_t]

        if not hasattr(self, "_node_clf") or self._node_clf is None:
            self._node_clf = nn.Sequential(
                nn.Linear(self.memory_dim, 64), nn.ReLU(), nn.Linear(64, 1)
            ).to(self.device)

        with torch.no_grad():
            logits = self._node_clf(node_emb).squeeze(-1)
            probs = torch.sigmoid(logits).cpu().numpy()

        return probs.astype(np.float32)

    def extract_prefix_embeddings(
        self,
        df_eval: pd.DataFrame,
        examples: pd.DataFrame,
    ) -> np.ndarray:
        """Replay ``df_eval`` event-by-event and capture sender memories.

        ``examples`` must provide ``sender_id`` and ``eval_local_event_idx``
        columns. Returns (len(examples), memory_dim); rows stay zero for
        examples whose (sender, event index) never occurs in the replay.
        A fresh streamer is used, so the wrapper's state is untouched.
        """
        assert self._module is not None, "Call fit() first."
        if examples.empty:
            return np.zeros((0, self.memory_dim), dtype=np.float32)

        df_eval = df_eval.sort_values("timestamp").reset_index(drop=True).copy()
        if "local_event_idx" not in df_eval.columns:
            df_eval["local_event_idx"] = df_eval.groupby("sender_id").cumcount().astype(np.int32)

        # (sender_id, local_event_idx) -> list of example row positions.
        capture_map: dict[tuple[int, int], list[int]] = {}
        for ex_idx, row in enumerate(examples.itertuples(index=False)):
            key = (int(row.sender_id), int(row.eval_local_event_idx))
            capture_map.setdefault(key, []).append(ex_idx)

        # Size the replay streamer for any id appearing in the eval stream.
        max_seen_id = int(max(df_eval["sender_id"].max(), df_eval["receiver_id"].max())) + 1
        streamer = _TGATStreamer(
            module=self._module,
            n_nodes=max(self._n_nodes, max_seen_id),
            memory_dim=self.memory_dim,
            edge_dim=self._edge_dim,
            n_neighbors=self.n_neighbors,
            device=self.device,
        )

        ns = self._norm_stats
        edge_feats_np = build_edge_features(df_eval).astype(np.float32)
        edge_feats_np = (edge_feats_np - ns["ea_mean"]) / ns["ea_std"]
        t_vals = df_eval["timestamp"].to_numpy(dtype=np.float32)
        t_norm = (t_vals - ns["t_min"]) / (ns["t_max"] - ns["t_min"] + 1e-6) * 5.0

        out = np.zeros((len(examples), self.memory_dim), dtype=np.float32)
        self._module.eval()
        with torch.no_grad():
            for idx, row in enumerate(df_eval.itertuples(index=False)):
                u = torch.tensor([int(row.sender_id)], dtype=torch.long, device=self.device)
                v = torch.tensor([int(row.receiver_id)], dtype=torch.long, device=self.device)
                t = torch.tensor([t_norm[idx]], dtype=torch.float32, device=self.device)
                ef = torch.tensor(edge_feats_np[idx:idx + 1], dtype=torch.float32, device=self.device)
                streamer.process_batch(u, v, t, ef, compute_grad=False)

                # Snapshot the sender's memory right after this event.
                key = (int(row.sender_id), int(row.local_event_idx))
                if key in capture_map:
                    emb = streamer.memory[int(row.sender_id)].detach().cpu().numpy().astype(np.float32)
                    for ex_idx in capture_map[key]:
                        out[ex_idx] = emb

        return out

    def reset_memory(self) -> None:
        """Clear streamer memory and buffers (learned weights are kept)."""
        if self._streamer is not None:
            self._streamer.memory.zero_()
            self._streamer.nbr_times = [[] for _ in range(self._n_nodes)]
            self._streamer.nbr_h = [[] for _ in range(self._n_nodes)]
            self._streamer.nbr_e = [[] for _ in range(self._n_nodes)]

    def train_node_classifier(
        self,
        eval_nodes: List[int],
        y_labels: np.ndarray,
        num_epochs: int = 150,
    ) -> None:
        """Fine-tune node classifier on node-level labels from training window.

        Trains ``_node_clf`` full-batch with Adam on the current (detached)
        memory rows of ``eval_nodes``; class imbalance is handled via a
        pos_weight capped at 10. Leaves the head in eval mode.
        """
        device = self.device
        eval_t = torch.tensor(eval_nodes, dtype=torch.long, device=device)
        node_emb = self._streamer.memory[eval_t].detach()
        y = torch.tensor(y_labels, dtype=torch.float32, device=device)

        pw = torch.clamp((y == 0).sum() / ((y == 1).sum() + 1e-6), max=10.0)
        loss_fn = nn.BCEWithLogitsLoss(pos_weight=pw)
        opt = torch.optim.Adam(self._node_clf.parameters(), lr=1e-3)

        self._node_clf.train()
        for _ in range(num_epochs):
            logits = self._node_clf(node_emb).squeeze(-1)
            loss = loss_fn(logits, y)
            opt.zero_grad()
            loss.backward()
            opt.step()
        self._node_clf.eval()
|