"""
models/static_gnn.py
====================
Static GNN Baseline: GraphSAGE with Snapshot Batching

Architecture
------------
Events are sorted by timestamp and cut into N consecutive snapshots of
roughly equal event count (the last snapshot takes any remainder).
For each snapshot:
  - Build a static homogeneous graph from the events in that bin
  - Run 2-layer GraphSAGE to produce node embeddings
  - Aggregate per-node embeddings across all snapshots (mean pooling)

A node classifier head is trained on the pooled embeddings.

This model has NO temporal memory between snapshots. It is the strongest
"static" baseline: it sees the full graph structure but cannot reason about
the ordering of events within or across snapshots.

Note: SAGEConv is used (from torch_geometric). Nodes that have no edges in a
snapshot are still passed through the encoder; nothing in this module
special-cases them.
"""

from __future__ import annotations

from typing import List

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import SAGEConv

from models.base import TemporalModel
from src.graph.graph_builder import build_edge_features

# Columns that must never reach the model: fit()/predict() refuse any frame
# that contains one of them.
_BLOCKED_COLS = frozenset({
    "motif_hit_count",
    "motif_source",
    "trigger_event_idx",
    "label_event_idx",
    "label_delay",
    "is_fallback_label",
    "fraud_source",
    "twin_role",
    "twin_label",
    "twin_pair_id",
    "template_id",
    "dynamic_fraud_state",
    "motif_chain_state",
    "motif_strength",
})


# ------------------------------------------------------------------ #
#  Core GraphSAGE nn.Module                                          #
# ------------------------------------------------------------------ #

class _SAGEEncoder(nn.Module):
    """Two stacked SAGEConv layers, each followed by LayerNorm.

    A ReLU sits between the two layers; the output of the second
    LayerNorm is returned without a further non-linearity.
    """

    def __init__(self, in_dim: int, hidden_dim: int):
        super().__init__()
        # Submodule names and creation order are kept stable: they determine
        # the state_dict keys and the order in which parameters are initialised.
        self.conv1 = SAGEConv(in_dim, hidden_dim)
        self.conv2 = SAGEConv(hidden_dim, hidden_dim)
        self.norm1 = nn.LayerNorm(hidden_dim)
        self.norm2 = nn.LayerNorm(hidden_dim)

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor:
        """Encode node features ``x`` over the graph given by ``edge_index``."""
        first = self.conv1(x, edge_index)
        first = self.norm1(first)
        first = F.relu(first)

        second = self.conv2(first, edge_index)
        return self.norm2(second)
# ------------------------------------------------------------------ #
#  StaticGNNWrapper (TemporalModel interface)                        #
# ------------------------------------------------------------------ #

class StaticGNNWrapper(TemporalModel):
    """GraphSAGE with time-snapshot aggregation. No temporal memory.

    ``fit`` trains the encoder with an edge-level fraud loss, one optimiser
    step per snapshot, then trains a small MLP on the mean of each node's
    per-snapshot embeddings. ``predict`` encodes the evaluation frame as a
    single graph and scores the requested nodes with that MLP.

    Input frames must provide the columns ``timestamp``, ``sender_id`` and
    ``receiver_id``; training frames additionally need ``is_fraud``. Node ids
    are used directly as row indices into the embedding table.
    """

    def __init__(
        self,
        hidden_dim: int = 64,
        n_snapshots: int = 10,
        device: str = "cpu",
    ):
        """
        Args:
            hidden_dim: width of the GraphSAGE embeddings.
            n_snapshots: number of time bins the training events are cut into.
            device: torch device string the model and tensors live on.

        Raises:
            ValueError: if ``n_snapshots`` is smaller than 1.
        """
        # Fail here instead of with a ZeroDivisionError (0) or a silent
        # no-op training loop (negative) deep inside fit().
        if n_snapshots < 1:
            raise ValueError(f"n_snapshots must be >= 1, got {n_snapshots}")

        self.hidden_dim = hidden_dim
        self.n_snapshots = n_snapshots
        self.device = torch.device(device)

        self._encoder: _SAGEEncoder | None = None
        self._node_clf: nn.Sequential | None = None
        self._norm_stats: dict | None = None
        self._n_nodes: int = 0
        self._node_emb_agg: torch.Tensor | None = None  # (n_nodes, hidden_dim)
        self._in_dim: int = 0

    @property
    def name(self) -> str:
        return "StaticGNN"

    @property
    def is_temporal(self) -> bool:
        return False

    # ------------------------------------------------------------------ #
    #  Small shared helpers                                              #
    # ------------------------------------------------------------------ #

    @staticmethod
    def _check_no_oracle_columns(df: pd.DataFrame, where: str) -> None:
        """Raise AssertionError if ``df`` carries any column from _BLOCKED_COLS.

        Raised explicitly (not via ``assert``) so the leakage guard stays
        active when Python runs with ``-O``.
        """
        leaked = _BLOCKED_COLS & set(df.columns)
        if leaked:
            raise AssertionError(
                f"Oracle columns leaked into StaticGNN.{where}(): {leaked}"
            )

    def _require_fitted(self) -> None:
        """Raise AssertionError unless fit() has populated the model state."""
        if self._encoder is None or self._node_clf is None or self._norm_stats is None:
            raise AssertionError("Call fit() first.")

    @staticmethod
    def _min_nodes_for(node_ids) -> int:
        """Smallest table size that makes every id in ``node_ids`` indexable."""
        return int(np.max(node_ids)) + 1 if len(node_ids) else 0

    def _snapshot_bounds(self, n_events: int) -> List[tuple]:
        """Return one ``(lo, hi)`` row range per snapshot bin.

        Every bin but the last holds ``max(1, n_events // n_snapshots)`` rows;
        the last bin runs to ``n_events`` and so absorbs the remainder. When
        there are fewer events than bins, the trailing ranges are empty.
        """
        bin_size = max(1, n_events // self.n_snapshots)
        last = self.n_snapshots - 1
        bounds = []
        for b in range(self.n_snapshots):
            lo = b * bin_size
            hi = lo + bin_size if b < last else n_events
            bounds.append((lo, hi))
        return bounds

    # ------------------------------------------------------------------ #

    def _build_snapshots(
        self, df: pd.DataFrame, ef_np: np.ndarray
    ) -> List[tuple]:
        """
        Returns list of (edge_index_t, edge_attr_t, src_nodes, dst_nodes)
        for each snapshot bin.

        ``ef_np`` must be row-aligned with ``df`` as passed in. The frame is
        ordered by timestamp with a stable sort and the same permutation is
        applied to ``ef_np``, so features stay attached to their edges and an
        already-sorted frame keeps its row order exactly.
        """
        ordered = df.reset_index(drop=True).sort_values("timestamp", kind="mergesort")
        order = ordered.index.to_numpy()

        senders = ordered["sender_id"].values.astype(np.int64)
        receivers = ordered["receiver_id"].values.astype(np.int64)
        feats = ef_np[order]

        snapshots = []
        for lo, hi in self._snapshot_bounds(len(ordered)):
            sub_u = senders[lo:hi]
            sub_v = receivers[lo:hi]
            edge_index = torch.tensor(np.vstack([sub_u, sub_v]), dtype=torch.long)
            edge_attr = torch.tensor(feats[lo:hi], dtype=torch.float32)
            snapshots.append((edge_index, edge_attr, sub_u, sub_v))
        return snapshots

    # ------------------------------------------------------------------ #

    def fit(self, df_train: pd.DataFrame, num_epochs: int = 3) -> None:
        """Train the encoder on snapshot graphs, then the node classifier.

        Args:
            df_train: training events (``timestamp``, ``sender_id``,
                ``receiver_id``, ``is_fraud`` plus whatever
                ``build_edge_features`` reads).
            num_epochs: passes over the snapshot sequence.

        Raises:
            AssertionError: if ``df_train`` contains a blocked oracle column.
            ValueError: if ``df_train`` is empty or ``num_epochs`` < 1.
        """
        self._check_no_oracle_columns(df_train, "fit")
        if len(df_train) == 0:
            raise ValueError("StaticGNN.fit() needs at least one training event.")
        if num_epochs < 1:
            # Without an epoch no pooled embeddings exist for the classifier.
            raise ValueError(f"num_epochs must be >= 1, got {num_epochs}")

        df_train = df_train.sort_values("timestamp").reset_index(drop=True)
        ef_np = build_edge_features(df_train).astype(np.float32)
        edge_dim = ef_np.shape[1]
        self._in_dim = edge_dim  # node features share the edge-feature width

        # Standardise edge features; the stats are reused at inference time.
        ea_mean = ef_np.mean(axis=0)
        ea_std = ef_np.std(axis=0) + 1e-6
        ef_np = (ef_np - ea_mean) / ea_std
        self._norm_stats = {"ea_mean": ea_mean, "ea_std": ea_std}

        all_ids = np.union1d(df_train["sender_id"].values, df_train["receiver_id"].values)
        n_nodes = int(all_ids.max()) + 1
        self._n_nodes = n_nodes

        device = self.device

        encoder = _SAGEEncoder(in_dim=edge_dim, hidden_dim=self.hidden_dim).to(device)
        self._encoder = encoder

        node_clf = nn.Sequential(
            nn.Linear(self.hidden_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 1),
        ).to(device)
        self._node_clf = node_clf

        # Snapshots and label slices come from the same bounds helper, so the
        # i-th label always belongs to the i-th edge of its snapshot.
        snapshots = self._build_snapshots(df_train, ef_np)
        bounds = self._snapshot_bounds(len(df_train))

        y_all = torch.tensor(df_train["is_fraud"].values, dtype=torch.float32)
        raw_pw = (y_all == 0).sum() / ((y_all == 1).sum() + 1e-6)
        pos_weight = torch.clamp(raw_pw, max=10.0).to(device)
        loss_fn_edge = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
        y_all = y_all.to(device)

        opt = torch.optim.Adam(
            list(encoder.parameters()) + list(node_clf.parameters()),
            lr=1e-3,
        )

        # Node inputs: per-sender mean of the normalised edge features over
        # the whole training frame (the same matrix is used for every snapshot).
        node_feat = self._build_node_feat(df_train, ef_np, n_nodes)
        x_full = torch.tensor(node_feat, dtype=torch.float32, device=device)

        # Move edge indices to the device once and drop empty bins: a bin with
        # no edges would make BCEWithLogitsLoss average over zero elements,
        # i.e. return NaN.
        train_bins = [
            (edge_index.to(device), lo, hi)
            for (edge_index, _attr, _src, _dst), (lo, hi) in zip(snapshots, bounds)
            if edge_index.shape[1] > 0
        ]

        for epoch in range(num_epochs):
            encoder.train()
            node_clf.train()
            total_loss = 0.0

            emb_accum = torch.zeros(n_nodes, self.hidden_dim, device=device)

            for edge_index, lo, hi in train_bins:
                h = encoder(x_full, edge_index)  # (n_nodes, hidden_dim)

                # Edge-level fraud loss: dot product of endpoint embeddings.
                edge_logits = (h[edge_index[0]] * h[edge_index[1]]).sum(dim=-1)
                edge_logits = torch.clamp(edge_logits, -10, 10)
                loss = loss_fn_edge(edge_logits, y_all[lo:hi])

                opt.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(encoder.parameters(), 1.0)
                opt.step()

                total_loss += loss.item()

                # Accumulate node embeddings across snapshots (detached).
                with torch.no_grad():
                    emb_accum += h.detach()

            # Mean over the snapshots that were actually encoded.
            emb_pooled = emb_accum / max(len(train_bins), 1)
            self._node_emb_agg = emb_pooled.clone()

            print(f"[StaticGNN] Epoch {epoch + 1}/{num_epochs} Loss: {total_loss:.4f}")

        # Encoder is done; train the node classifier on the pooled embeddings.
        self._train_node_clf(df_train)

    # ------------------------------------------------------------------ #

    def _encode_graph(self, df: pd.DataFrame, min_nodes: int = 0) -> torch.Tensor:
        """Encode ``df`` as one static graph and return all node embeddings.

        The embedding table has at least ``max(self._n_nodes, min_nodes)``
        rows, and more if ``df`` mentions larger ids. Nodes that never appear
        as a sender get an all-zero input feature row. An empty ``df`` yields
        a graph without edges instead of raising.
        """
        self._require_fitted()
        device = self.device
        n_nodes = max(self._n_nodes, int(min_nodes))

        if len(df) == 0:
            x = torch.zeros(n_nodes, self._in_dim, dtype=torch.float32, device=device)
            edge_index = torch.zeros((2, 0), dtype=torch.long, device=device)
        else:
            ns = self._norm_stats
            df = df.sort_values("timestamp").reset_index(drop=True)
            ef_np = build_edge_features(df).astype(np.float32)
            ef_np = (ef_np - ns["ea_mean"]) / ns["ea_std"]

            senders = df["sender_id"].values.astype(np.int64)
            receivers = df["receiver_id"].values.astype(np.int64)
            n_nodes = max(n_nodes, int(max(senders.max(), receivers.max())) + 1)

            node_feat = self._build_node_feat(df, ef_np, n_nodes)
            x = torch.tensor(node_feat, dtype=torch.float32, device=device)
            edge_index = torch.tensor(
                np.vstack([senders, receivers]),
                dtype=torch.long, device=device,
            )

        self._encoder.eval()
        with torch.no_grad():
            return self._encoder(x, edge_index)

    def _compute_prefix_embeddings(
        self, df_prefix: pd.DataFrame, min_nodes: int = 0
    ) -> torch.Tensor:
        """Compute node embeddings for a causal prefix graph.

        ``min_nodes`` optionally enlarges the returned table so that ids not
        present in ``df_prefix`` can still be indexed.
        """
        return self._encode_graph(df_prefix, min_nodes=min_nodes)

    # ------------------------------------------------------------------ #

    def _build_node_feat(
        self, df: pd.DataFrame, ef_np: np.ndarray, n_nodes: int
    ) -> np.ndarray:
        """Aggregate edge features to sender nodes (mean).

        Returns an ``(n_nodes, ef_np.shape[1])`` float32 array; rows of nodes
        that never send stay zero.
        """
        feat = np.zeros((n_nodes, ef_np.shape[1]), dtype=np.float32)
        cnt = np.zeros(n_nodes, dtype=np.float32)
        sids = df["sender_id"].values.astype(np.int64)
        # np.add.at accumulates correctly when a sender id repeats.
        np.add.at(feat, sids, ef_np)
        np.add.at(cnt, sids, 1.0)
        cnt = np.maximum(cnt, 1.0)  # avoid 0/0 for non-senders
        return feat / cnt[:, None]

    def _fit_node_clf(
        self, node_emb: torch.Tensor, y: torch.Tensor, num_epochs: int
    ) -> None:
        """Full-batch Adam training of the node classifier on fixed embeddings.

        Uses BCE-with-logits with a positive-class weight of #neg / #pos,
        capped at 10. The classifier is left in eval mode. An empty label set
        is skipped, since its loss would be NaN.
        """
        self._require_fitted()
        clf = self._node_clf
        if y.numel() == 0:
            clf.eval()
            return

        pw = torch.clamp((y == 0).sum() / ((y == 1).sum() + 1e-6), max=10.0)
        loss_fn = nn.BCEWithLogitsLoss(pos_weight=pw)
        opt = torch.optim.Adam(clf.parameters(), lr=1e-3)

        clf.train()
        for _ in range(num_epochs):
            logits = clf(node_emb).squeeze(-1)
            loss = loss_fn(logits, y)
            opt.zero_grad()
            loss.backward()
            opt.step()
        clf.eval()

    def _train_node_clf(self, df_train: pd.DataFrame, num_epochs: int = 150) -> None:
        """Fine-tune node classifier on node-level fraud labels (training split).

        A sender is labelled positive if any of its training events is fraud.
        Only nodes that appear as a sender are used.
        """
        device = self.device

        # groupby sorts its keys, so index and values are aligned and ordered.
        y_map = df_train.groupby("sender_id")["is_fraud"].max()
        sender_ids = y_map.index.to_numpy().astype(np.int64)
        y_np = y_map.to_numpy().astype(np.float32)

        eval_t = torch.tensor(sender_ids, dtype=torch.long, device=device)
        y = torch.tensor(y_np, device=device)
        node_emb = self._node_emb_agg[eval_t].detach()

        self._fit_node_clf(node_emb, y, num_epochs)

    # ------------------------------------------------------------------ #

    def predict(self, df_eval: pd.DataFrame, eval_nodes: List[int]) -> np.ndarray:
        """Return a fraud probability for each node in ``eval_nodes``.

        The evaluation frame is encoded as one static graph (no memory).
        Nodes in ``eval_nodes`` that appear neither in ``df_eval`` nor in the
        training id range are scored from an all-zero feature row instead of
        raising IndexError.

        Returns:
            float32 array with one probability per entry of ``eval_nodes``.

        Raises:
            AssertionError: if fit() has not been called, or if ``df_eval``
                contains a blocked oracle column.
        """
        self._require_fitted()
        self._check_no_oracle_columns(df_eval, "predict")

        h = self._encode_graph(df_eval, min_nodes=self._min_nodes_for(eval_nodes))

        eval_t = torch.tensor(eval_nodes, dtype=torch.long, device=self.device)
        with torch.no_grad():
            logits = self._node_clf(h[eval_t]).squeeze(-1)
            probs = torch.sigmoid(logits).cpu().numpy()
        return probs.astype(np.float32)

    # ------------------------------------------------------------------ #

    def reset_memory(self) -> None:
        """No-op: StaticGNN has no temporal memory."""
        pass

    # ------------------------------------------------------------------ #

    def train_node_classifier(
        self, eval_nodes: List[int], y_labels: np.ndarray, num_epochs: int = 150
    ) -> None:
        """Re-train node classifier with fresh labels (for horizon sweep).

        Uses the pooled training embeddings stored by fit(), so every id in
        ``eval_nodes`` must lie inside the training id range.
        """
        self._require_fitted()
        device = self.device
        eval_t = torch.tensor(eval_nodes, dtype=torch.long, device=device)
        node_emb = self._node_emb_agg[eval_t].detach()
        y = torch.tensor(y_labels, dtype=torch.float32, device=device)
        self._fit_node_clf(node_emb, y, num_epochs)

    def train_node_classifier_on_prefix(
        self,
        df_prefix: pd.DataFrame,
        eval_nodes: List[int],
        y_labels: np.ndarray,
        num_epochs: int = 150,
    ) -> None:
        """Train the node classifier on embeddings computed from a causal prefix."""
        device = self.device
        prefix_emb = self._compute_prefix_embeddings(
            df_prefix, min_nodes=self._min_nodes_for(eval_nodes)
        )
        eval_t = torch.tensor(eval_nodes, dtype=torch.long, device=device)
        node_emb = prefix_emb[eval_t].detach()
        y = torch.tensor(y_labels, dtype=torch.float32, device=device)
        self._fit_node_clf(node_emb, y, num_epochs)