| """ |
| models/tgn_wrapper.py |
| ===================== |
| Wraps the existing src/tgn/ pipeline behind the TemporalModel interface. |
| |
| Architecture (unchanged from src/tgn/model.py): |
| - GRU-based memory module |
| - Message MLP (memory × 2 + edge + time → memory) |
| - Node classifier head: memory + static_feat → fraud prob |
| """ |
|
|
| from __future__ import annotations |
|
|
| import copy |
| from typing import List |
|
|
| import numpy as np |
| import pandas as pd |
| import torch |
|
|
| from models.base import TemporalModel |
| from src.graph.dataset_builder import build_graph_dataset |
| from src.graph.graph_builder import build_edge_features |
| from src.tgn.memory import Memory |
| from src.tgn.model import TGN |
| from src.tgn.time_encoding import TimeEncoding |
| from src.tgn.train import train_tgn |
|
|
|
|
class TGNWrapper(TemporalModel):
    """TGN with GRU memory, wrapped behind the unified TemporalModel interface.

    Lifecycle:
      1. ``fit()`` trains the underlying TGN via ``train_tgn`` and keeps the
         trained model, its memory module, the time encoder and the
         normalisation statistics.
      2. ``predict()`` replays an evaluation event stream through the trained
         memory and scores nodes with the node-classifier head.
      3. ``extract_prefix_embeddings()`` replays events through a *fresh*
         memory and captures per-sender memory snapshots at requested events.
    """

    def __init__(
        self,
        memory_dim: int = 64,
        time_dim: int = 16,
        hidden_dim: int = 128,
        device: str = "cpu",
    ):
        # Hyper-parameters of the wrapped TGN (see src/tgn/model.py).
        self.memory_dim = memory_dim
        self.time_dim = time_dim
        self.hidden_dim = hidden_dim
        self.device = torch.device(device)

        # State populated by fit(); placeholders until then.
        self._model: TGN | None = None
        self._memory: Memory | None = None
        self._time_encoder: TimeEncoding | None = None
        # Normalisation statistics returned by train_tgn. Presumably holds
        # "ea_mean"/"ea_std" (edge attrs), "t_min"/"t_max" (time range) and
        # "x" (static node features) -- TODO confirm against src/tgn/train.py.
        self._norm_stats: dict | None = None
        self._num_nodes: int = 0  # rows in the trained memory table
        self._users: pd.DataFrame | None = None
        self._node_head_fitted = False  # set once _train_node_head() succeeds

    @property
    def name(self) -> str:
        return "TGN"

    def fit(self, df_train: pd.DataFrame, num_epochs: int = 3) -> None:
        """Train the TGN on *df_train*.

        *df_train* must provide at least ``sender_id``, ``receiver_id`` and
        ``timestamp`` columns (integer node ids, numeric timestamps).
        """
        # Events must be replayed chronologically for memory updates.
        df_train = df_train.sort_values("timestamp").reset_index(drop=True)

        users = _make_users_df(df_train)
        self._users = users

        graph_data = build_graph_dataset(df_train, users)
        # Every event participates in training; no held-out edge mask here.
        graph_data["train_mask"] = np.ones(len(df_train), dtype=bool)

        self._model, self._memory, self._time_encoder, self._norm_stats = train_tgn(
            graph_data, num_epochs=num_epochs
        )
        self._num_nodes = self._memory.memory.shape[0]

    def predict(self, df_eval: pd.DataFrame, eval_nodes: List[int]) -> np.ndarray:
        """Replay *df_eval* through the memory, then score *eval_nodes*.

        Returns a float32 array of fraud probabilities aligned with
        *eval_nodes*. Side effect: mutates ``self._memory`` (use
        ``reset_memory()`` to clear it afterwards).
        """
        # NOTE(review): `assert` is stripped under `python -O`; raising a
        # RuntimeError would be more robust for input validation.
        assert self._model is not None, "Call fit() first."
        df_eval = df_eval.sort_values("timestamp").reset_index(drop=True)

        device = self.device
        model = self._model
        memory = self._memory
        time_encoder = self._time_encoder
        ns = self._norm_stats

        edge_index = torch.tensor(
            np.vstack([df_eval["sender_id"].values, df_eval["receiver_id"].values]),
            dtype=torch.long,
        )
        edge_attr = torch.tensor(
            build_edge_features(df_eval), dtype=torch.float32
        )
        # Normalise edge features with the training statistics.
        # NOTE(review): this assumes ns["ea_mean"]/ns["ea_std"] are tensors,
        # whereas extract_prefix_embeddings() handles both tensor and ndarray
        # -- confirm which type train_tgn actually returns.
        edge_attr = (edge_attr - ns["ea_mean"]) / ns["ea_std"]

        # Min-max normalise timestamps to the training range; the epsilon
        # avoids division by zero when all training timestamps are equal.
        timestamps = torch.tensor(df_eval["timestamp"].values, dtype=torch.float32)
        timestamps = (timestamps - ns["t_min"]) / (ns["t_max"] - ns["t_min"] + 1e-6)

        batch_size = 1024
        model.eval()
        with torch.no_grad():
            for i in range(0, len(df_eval), batch_size):
                ids = range(i, min(i + batch_size, len(df_eval)))
                u = edge_index[0, ids].to(device)
                v = edge_index[1, ids].to(device)
                ef = edge_attr[ids].to(device)
                # *5.0 presumably mirrors the time scaling applied during
                # training (extract_prefix_embeddings uses the same factor)
                # -- TODO confirm against src/tgn/train.py.
                t = timestamps[ids].to(device) * 5.0

                time_enc = time_encoder(t)
                h_u = memory.get(u)
                h_v = memory.get(v)
                msg = model.compute_message(h_u, h_v, ef, time_enc)

                # Both endpoints receive the same message; average all
                # messages landing on the same node within the batch before
                # writing the memory update.
                node_ids = torch.cat([u, v])
                messages = torch.cat([msg, msg])
                unique_nodes, inv = torch.unique(node_ids, return_inverse=True)
                agg = torch.zeros_like(memory.memory[unique_nodes])
                agg.index_add_(0, inv, messages)
                # counts stays integer-typed; torch true division promotes it
                # (extract_prefix_embeddings calls .float() explicitly).
                counts = torch.bincount(inv).unsqueeze(1)
                memory.update(unique_nodes, agg / counts)

        # NOTE(review): ids beyond the trained memory are clamped to the last
        # row, silently aliasing every unseen node onto one memory slot.
        eval_nodes_clamped = [min(n, self._num_nodes - 1) for n in eval_nodes]
        eval_nodes_t = torch.tensor(eval_nodes_clamped, dtype=torch.long, device=device)
        node_emb = memory.memory[eval_nodes_t].clone()
        # Static features are zeroed at inference; only the width is taken
        # from the training stats.
        x_zeros = torch.zeros(len(eval_nodes), ns["x"].shape[1], device=device)

        model.eval()
        with torch.no_grad():
            combined = torch.cat([node_emb, x_zeros], dim=1)
            probs = torch.sigmoid(
                model.node_classifier(combined).squeeze(-1)
            ).cpu().numpy()

        return probs.astype(np.float32)

    def extract_prefix_embeddings(
        self,
        df_eval: pd.DataFrame,
        examples: pd.DataFrame,
    ) -> np.ndarray:
        """Replay *df_eval* through a fresh memory, capturing sender embeddings.

        Each *examples* row identifies a capture point by
        ``(sender_id, eval_local_event_idx)``: right after replaying that
        sender's ``eval_local_event_idx``-th event, the sender's memory vector
        is copied into the output row for that example.

        Returns a float32 array of shape ``(len(examples), memory_dim)``;
        rows whose capture point never occurs in *df_eval* stay zero.
        """
        assert self._model is not None, "Call fit() first."
        if examples.empty:
            return np.zeros((0, self.memory_dim), dtype=np.float32)

        df_eval = df_eval.sort_values("timestamp").reset_index(drop=True).copy()
        # Per-sender event counter, used to match examples' capture indices.
        if "local_event_idx" not in df_eval.columns:
            df_eval["local_event_idx"] = df_eval.groupby("sender_id").cumcount().astype(np.int32)

        # (sender_id, local_event_idx) -> list of example row indices to fill.
        capture_map: dict[tuple[int, int], list[int]] = {}
        for ex_idx, row in enumerate(examples.itertuples(index=False)):
            key = (int(row.sender_id), int(row.eval_local_event_idx))
            capture_map.setdefault(key, []).append(ex_idx)

        # Fresh memory sized to also cover ids seen only at eval time.
        max_seen_id = int(max(df_eval["sender_id"].max(), df_eval["receiver_id"].max())) + 1
        num_nodes = max(self._num_nodes, max_seen_id)
        device = self.device
        model = self._model
        time_encoder = self._time_encoder
        ns = self._norm_stats
        memory = Memory(num_nodes, memory_dim=self.memory_dim, device=device)

        # Normalisation stats may arrive as tensors or plain scalars/arrays.
        ea_mean = ns["ea_mean"].detach().cpu().numpy() if isinstance(ns["ea_mean"], torch.Tensor) else np.asarray(ns["ea_mean"], dtype=np.float32)
        ea_std = ns["ea_std"].detach().cpu().numpy() if isinstance(ns["ea_std"], torch.Tensor) else np.asarray(ns["ea_std"], dtype=np.float32)
        t_min = float(ns["t_min"].item()) if isinstance(ns["t_min"], torch.Tensor) else float(ns["t_min"])
        t_max = float(ns["t_max"].item()) if isinstance(ns["t_max"], torch.Tensor) else float(ns["t_max"])

        edge_attr = build_edge_features(df_eval).astype(np.float32)
        edge_attr = (edge_attr - ea_mean) / ea_std
        timestamps = df_eval["timestamp"].to_numpy(dtype=np.float32)
        timestamps = (timestamps - t_min) / (t_max - t_min + 1e-6)
        # Same *5.0 time scale as predict() -- presumably matches training.
        timestamps = timestamps * 5.0

        out = np.zeros((len(examples), self.memory_dim), dtype=np.float32)

        model.eval()
        with torch.no_grad():
            # One event at a time so a snapshot can be taken between events.
            for idx, row in enumerate(df_eval.itertuples(index=False)):
                u = torch.tensor([int(row.sender_id)], dtype=torch.long, device=device)
                v = torch.tensor([int(row.receiver_id)], dtype=torch.long, device=device)
                ef = torch.tensor(edge_attr[idx:idx + 1], dtype=torch.float32, device=device)
                t = torch.tensor([timestamps[idx]], dtype=torch.float32, device=device)

                time_enc = time_encoder(t)
                h_u = memory.get(u)
                h_v = memory.get(v)
                msg = model.compute_message(h_u, h_v, ef, time_enc)

                # Mean-aggregate messages per node (a self-loop u == v gets a
                # single averaged update instead of two writes).
                node_ids = torch.cat([u, v])
                messages = torch.cat([msg, msg], dim=0)
                unique_nodes, inverse_idx = torch.unique(node_ids, return_inverse=True)
                agg_msg = torch.zeros((len(unique_nodes), self.memory_dim), device=device)
                agg_msg.index_add_(0, inverse_idx, messages)
                counts = torch.bincount(inverse_idx).unsqueeze(1).float()
                memory.update(unique_nodes, agg_msg / counts)

                # Snapshot the sender's memory right after its k-th event.
                key = (int(row.sender_id), int(row.local_event_idx))
                if key in capture_map:
                    emb = memory.memory[int(row.sender_id)].detach().cpu().numpy().astype(np.float32)
                    for ex_idx in capture_map[key]:
                        out[ex_idx] = emb

        return out

    def reset_memory(self) -> None:
        """Zero the trained memory in place (no-op before fit())."""
        if self._memory is not None:
            self._memory.memory.zero_()

    def _train_node_head(
        self,
        eval_nodes: List[int],
        y_train: np.ndarray,
        num_epochs: int = 100,
    ) -> None:
        """Fine-tune the node classifier head on training labels.

        Only ``model.node_classifier`` receives gradients; the rest of the
        model is frozen for the duration and unfrozen afterwards. Raises
        AssertionError if the head never saw a finite gradient.
        """
        assert self._model is not None
        device = self.device
        model = self._model
        memory = self._memory

        eval_nodes_t = torch.tensor(eval_nodes, dtype=torch.long, device=device)
        # Static features zeroed, mirroring predict(); only the width matters.
        x = torch.zeros(len(eval_nodes), self._norm_stats["x"].shape[1], device=device)
        y = torch.tensor(y_train, dtype=torch.float32, device=device)
        saw_grad = False

        # Freeze everything except the classifier head.
        for p in model.parameters():
            p.requires_grad = False
        for p in model.node_classifier.parameters():
            p.requires_grad = True

        opt = torch.optim.Adam(model.node_classifier.parameters(), lr=1e-3)
        # Up-weight the positive class, capped at 10x to keep the loss stable
        # on extremely imbalanced label sets.
        pw = torch.clamp((y == 0).sum() / ((y == 1).sum() + 1e-6), max=10.0)
        loss_fn = torch.nn.BCEWithLogitsLoss(pos_weight=pw)

        model.train()
        for _ in range(num_epochs):
            # detach(): memory embeddings are treated as fixed inputs here.
            node_emb = memory.memory[eval_nodes_t].detach()
            combined = torch.cat([node_emb, x], dim=1)
            logits = model.node_classifier(combined).squeeze(-1)
            loss = loss_fn(logits, y)
            opt.zero_grad()
            loss.backward()
            # Sanity check that gradients actually flow into the head.
            saw_grad = saw_grad or any(
                p.grad is not None and torch.isfinite(p.grad).all()
                for p in model.node_classifier.parameters()
            )
            opt.step()

        # Unfreeze the full model for any later training.
        for p in model.parameters():
            p.requires_grad = True

        assert saw_grad, "TGN node classifier did not receive gradients."
        self._node_head_fitted = True

    def train_node_classifier(
        self,
        eval_nodes: List[int],
        y_labels: np.ndarray,
        num_epochs: int = 100,
    ) -> None:
        """Public wrapper around ``_train_node_head``."""
        self._train_node_head(eval_nodes, y_labels, num_epochs=num_epochs)
|
|
|
|
| |
| |
| |
|
|
| def _make_users_df(df: pd.DataFrame) -> pd.DataFrame: |
| """Create a minimal users DataFrame from sender_ids in df.""" |
| max_id = int(max(df["sender_id"].max(), df["receiver_id"].max())) |
| return pd.DataFrame({"user_id": np.arange(max_id + 1, dtype=np.int64)}) |
|
|