""" models/tgn_wrapper.py ===================== Wraps the existing src/tgn/ pipeline behind the TemporalModel interface. Architecture (unchanged from src/tgn/model.py): - GRU-based memory module - Message MLP (memory × 2 + edge + time → memory) - Node classifier head: memory + static_feat → fraud prob """ from __future__ import annotations import copy from typing import List import numpy as np import pandas as pd import torch from models.base import TemporalModel from src.graph.dataset_builder import build_graph_dataset from src.graph.graph_builder import build_edge_features from src.tgn.memory import Memory from src.tgn.model import TGN from src.tgn.time_encoding import TimeEncoding from src.tgn.train import train_tgn class TGNWrapper(TemporalModel): """TGN with GRU memory, wrapped behind the unified TemporalModel interface.""" def __init__( self, memory_dim: int = 64, time_dim: int = 16, hidden_dim: int = 128, device: str = "cpu", ): self.memory_dim = memory_dim self.time_dim = time_dim self.hidden_dim = hidden_dim self.device = torch.device(device) # filled by fit() self._model: TGN | None = None self._memory: Memory | None = None self._time_encoder: TimeEncoding | None = None self._norm_stats: dict | None = None self._num_nodes: int = 0 self._users: pd.DataFrame | None = None self._node_head_fitted = False @property def name(self) -> str: return "TGN" # ------------------------------------------------------------------ # def fit(self, df_train: pd.DataFrame, num_epochs: int = 3) -> None: df_train = df_train.sort_values("timestamp").reset_index(drop=True) # build_graph_dataset expects a users DataFrame; derive a minimal one users = _make_users_df(df_train) self._users = users graph_data = build_graph_dataset(df_train, users) # Override train_mask to use ALL training events graph_data["train_mask"] = np.ones(len(df_train), dtype=bool) self._model, self._memory, self._time_encoder, self._norm_stats = train_tgn( graph_data, num_epochs=num_epochs ) self._num_nodes = self._memory.memory.shape[0] # ------------------------------------------------------------------ # def predict(self, df_eval: pd.DataFrame, eval_nodes: List[int]) -> np.ndarray: assert self._model is not None, "Call fit() first." 

    # ------------------------------------------------------------------ #
    def predict(self, df_eval: pd.DataFrame, eval_nodes: List[int]) -> np.ndarray:
        assert self._model is not None, "Call fit() first."
        df_eval = df_eval.sort_values("timestamp").reset_index(drop=True)

        device = self.device
        model = self._model
        memory = self._memory
        time_encoder = self._time_encoder
        ns = self._norm_stats

        # Warm-up: pass eval events through memory (no label access)
        edge_index = torch.tensor(
            np.vstack([df_eval["sender_id"].values, df_eval["receiver_id"].values]),
            dtype=torch.long,
        )
        # Clamp node ids unseen during training so memory indexing stays in
        # range, mirroring the clamping applied at scoring time below
        edge_index = edge_index.clamp_(max=self._num_nodes - 1)
        edge_attr = torch.tensor(build_edge_features(df_eval), dtype=torch.float32)
        edge_attr = (edge_attr - ns["ea_mean"]) / ns["ea_std"]
        timestamps = torch.tensor(df_eval["timestamp"].values, dtype=torch.float32)
        timestamps = (timestamps - ns["t_min"]) / (ns["t_max"] - ns["t_min"] + 1e-6)

        batch_size = 1024
        model.eval()
        with torch.no_grad():
            for i in range(0, len(df_eval), batch_size):
                j = min(i + batch_size, len(df_eval))
                u = edge_index[0, i:j].to(device)
                v = edge_index[1, i:j].to(device)
                ef = edge_attr[i:j].to(device)
                # same ×5 time scale as in extract_prefix_embeddings
                t = timestamps[i:j].to(device) * 5.0
                time_enc = time_encoder(t)

                h_u = memory.get(u)
                h_v = memory.get(v)
                msg = model.compute_message(h_u, h_v, ef, time_enc)

                # Scatter-mean the per-event messages into each touched node
                node_ids = torch.cat([u, v])
                messages = torch.cat([msg, msg])
                unique_nodes, inv = torch.unique(node_ids, return_inverse=True)
                agg = torch.zeros_like(memory.memory[unique_nodes])
                agg.index_add_(0, inv, messages)
                counts = torch.bincount(inv).unsqueeze(1)
                memory.update(unique_nodes, agg / counts)

        # Score eval nodes (clamp to valid range for OOD nodes)
        eval_nodes_clamped = [min(n, self._num_nodes - 1) for n in eval_nodes]
        eval_nodes_t = torch.tensor(eval_nodes_clamped, dtype=torch.long, device=device)
        node_emb = memory.memory[eval_nodes_t].clone()
        # static node features are unavailable here; zeros match how the
        # node head is trained in _train_node_head
        x_zeros = torch.zeros(len(eval_nodes), ns["x"].shape[1], device=device)

        model.eval()
        with torch.no_grad():
            combined = torch.cat([node_emb, x_zeros], dim=1)
            probs = torch.sigmoid(
                model.node_classifier(combined).squeeze(-1)
            ).cpu().numpy()
        return probs.astype(np.float32)
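
    # The warm-up loop above folds each batch into memory with a scatter-mean.
    # A worked example of the unique / index_add_ / bincount pattern (values
    # illustrative only):
    #
    #     node_ids = [3, 7, 3]  ->  unique_nodes = [3, 7], inv = [0, 1, 0]
    #     agg[0]   = messages[0] + messages[2]   # both events touch node 3
    #     agg[1]   = messages[1]
    #     counts   = [2, 1]  ->  memory.update() receives the per-node mean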

    def extract_prefix_embeddings(
        self,
        df_eval: pd.DataFrame,
        examples: pd.DataFrame,
    ) -> np.ndarray:
        assert self._model is not None, "Call fit() first."
        if examples.empty:
            return np.zeros((0, self.memory_dim), dtype=np.float32)

        df_eval = df_eval.sort_values("timestamp").reset_index(drop=True).copy()
        if "local_event_idx" not in df_eval.columns:
            df_eval["local_event_idx"] = (
                df_eval.groupby("sender_id").cumcount().astype(np.int32)
            )

        # Map (sender_id, local_event_idx) -> example rows that want the
        # sender's memory snapshot right after that event
        capture_map: dict[tuple[int, int], list[int]] = {}
        for ex_idx, row in enumerate(examples.itertuples(index=False)):
            key = (int(row.sender_id), int(row.eval_local_event_idx))
            capture_map.setdefault(key, []).append(ex_idx)

        max_seen_id = int(max(df_eval["sender_id"].max(), df_eval["receiver_id"].max())) + 1
        num_nodes = max(self._num_nodes, max_seen_id)

        device = self.device
        model = self._model
        time_encoder = self._time_encoder
        ns = self._norm_stats
        # Fresh memory sized to cover eval-only nodes as well
        memory = Memory(num_nodes, memory_dim=self.memory_dim, device=device)

        # Normalisation stats may arrive as tensors or plain floats/arrays
        ea_mean = (
            ns["ea_mean"].detach().cpu().numpy()
            if isinstance(ns["ea_mean"], torch.Tensor)
            else np.asarray(ns["ea_mean"], dtype=np.float32)
        )
        ea_std = (
            ns["ea_std"].detach().cpu().numpy()
            if isinstance(ns["ea_std"], torch.Tensor)
            else np.asarray(ns["ea_std"], dtype=np.float32)
        )
        t_min = float(ns["t_min"].item()) if isinstance(ns["t_min"], torch.Tensor) else float(ns["t_min"])
        t_max = float(ns["t_max"].item()) if isinstance(ns["t_max"], torch.Tensor) else float(ns["t_max"])

        edge_attr = build_edge_features(df_eval).astype(np.float32)
        edge_attr = (edge_attr - ea_mean) / ea_std
        timestamps = df_eval["timestamp"].to_numpy(dtype=np.float32)
        timestamps = (timestamps - t_min) / (t_max - t_min + 1e-6)
        timestamps = timestamps * 5.0

        out = np.zeros((len(examples), self.memory_dim), dtype=np.float32)

        model.eval()
        with torch.no_grad():
            # Replay eval events one at a time so memory can be snapshotted
            # at exact event boundaries
            for idx, row in enumerate(df_eval.itertuples(index=False)):
                u = torch.tensor([int(row.sender_id)], dtype=torch.long, device=device)
                v = torch.tensor([int(row.receiver_id)], dtype=torch.long, device=device)
                ef = torch.tensor(edge_attr[idx:idx + 1], dtype=torch.float32, device=device)
                t = torch.tensor([timestamps[idx]], dtype=torch.float32, device=device)
                time_enc = time_encoder(t)

                h_u = memory.get(u)
                h_v = memory.get(v)
                msg = model.compute_message(h_u, h_v, ef, time_enc)

                node_ids = torch.cat([u, v])
                messages = torch.cat([msg, msg], dim=0)
                unique_nodes, inverse_idx = torch.unique(node_ids, return_inverse=True)
                agg_msg = torch.zeros((len(unique_nodes), self.memory_dim), device=device)
                agg_msg.index_add_(0, inverse_idx, messages)
                counts = torch.bincount(inverse_idx).unsqueeze(1).float()
                memory.update(unique_nodes, agg_msg / counts)

                key = (int(row.sender_id), int(row.local_event_idx))
                if key in capture_map:
                    emb = (
                        memory.memory[int(row.sender_id)]
                        .detach().cpu().numpy().astype(np.float32)
                    )
                    for ex_idx in capture_map[key]:
                        out[ex_idx] = emb

        return out

    # ------------------------------------------------------------------ #
    def reset_memory(self) -> None:
        if self._memory is not None:
            self._memory.memory.zero_()
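
    # _train_node_head below rebalances BCE with pos_weight = #neg / #pos,
    # clamped at 10. Illustrative numbers: 990 legit vs. 10 fraud labels give
    # 990 / 10 = 99, clamped to 10, so each positive counts 10x in the loss.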

    # ------------------------------------------------------------------ #
    def _train_node_head(
        self,
        eval_nodes: List[int],
        y_train: np.ndarray,
        num_epochs: int = 100,
    ) -> None:
        """Fine-tune the node classifier head on training labels."""
        assert self._model is not None
        device = self.device
        model = self._model
        memory = self._memory

        eval_nodes_t = torch.tensor(eval_nodes, dtype=torch.long, device=device)
        x = torch.zeros(len(eval_nodes), self._norm_stats["x"].shape[1], device=device)
        y = torch.tensor(y_train, dtype=torch.float32, device=device)

        # Train only the head: freeze everything else
        saw_grad = False
        for p in model.parameters():
            p.requires_grad = False
        for p in model.node_classifier.parameters():
            p.requires_grad = True

        opt = torch.optim.Adam(model.node_classifier.parameters(), lr=1e-3)
        pw = torch.clamp((y == 0).sum() / ((y == 1).sum() + 1e-6), max=10.0)
        loss_fn = torch.nn.BCEWithLogitsLoss(pos_weight=pw)

        model.train()
        for _ in range(num_epochs):
            node_emb = memory.memory[eval_nodes_t].detach()
            combined = torch.cat([node_emb, x], dim=1)
            logits = model.node_classifier(combined).squeeze(-1)
            loss = loss_fn(logits, y)
            opt.zero_grad()
            loss.backward()
            saw_grad = saw_grad or any(
                p.grad is not None and torch.isfinite(p.grad).all()
                for p in model.node_classifier.parameters()
            )
            opt.step()

        for p in model.parameters():
            p.requires_grad = True
        assert saw_grad, "TGN node classifier did not receive gradients."
        self._node_head_fitted = True

    def train_node_classifier(
        self,
        eval_nodes: List[int],
        y_labels: np.ndarray,
        num_epochs: int = 100,
    ) -> None:
        self._train_node_head(eval_nodes, y_labels, num_epochs=num_epochs)


# ------------------------------------------------------------------ #
# Helpers
# ------------------------------------------------------------------ #
def _make_users_df(df: pd.DataFrame) -> pd.DataFrame:
    """Create a minimal users DataFrame covering all sender/receiver ids in df."""
    max_id = int(max(df["sender_id"].max(), df["receiver_id"].max()))
    return pd.DataFrame({"user_id": np.arange(max_id + 1, dtype=np.int64)})
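

# A self-contained sketch (hedged; synthetic rows only, no src/tgn training
# run) of the per-sender `local_event_idx` bookkeeping that
# extract_prefix_embeddings relies on to snapshot memory right after a
# sender's k-th event.
if __name__ == "__main__":
    demo = pd.DataFrame(
        {
            "sender_id": [0, 1, 0, 0, 1],
            "receiver_id": [1, 2, 2, 1, 0],
            "timestamp": [1.0, 2.0, 3.0, 4.0, 5.0],
        }
    )
    demo["local_event_idx"] = demo.groupby("sender_id").cumcount().astype(np.int32)
    # sender 0 emits local event indices 0, 1, 2; sender 1 emits 0, 1
    print(demo.to_string(index=False))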