| """ |
| models/static_gnn.py |
| ==================== |
| Static GNN Baseline: GraphSAGE with Snapshot Batching |
| |
| Architecture |
| ------------ |
| Events are binned into N time-snapshots (equal-count bins). |
| For each snapshot: |
| - Build a static homogeneous graph from the events in that bin |
| - Run 2-layer GraphSAGE to produce node embeddings |
| - Aggregate per-node embeddings across all snapshots (mean pooling) |
| A node classifier head is trained on the pooled embeddings. |
| |
| This model has NO temporal memory between snapshots. It is the strongest |
| "static" baseline: it sees the full graph structure but cannot reason about |
| the ordering of events within or across snapshots. |
| |
| Note: SAGEConv is used (from torch_geometric). Falls back gracefully when |
| a node has no edges in a snapshot (embedding stays at zero for that snapshot). |
| """ |
|
|
| from __future__ import annotations |
|
|
| from typing import List |
|
|
| import numpy as np |
| import pandas as pd |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from torch_geometric.nn import SAGEConv |
|
|
| from models.base import TemporalModel |
| from src.graph.graph_builder import build_edge_features |
|
|
# Oracle / label-derived columns that must never be visible to the model.
# fit() and predict() assert that none of these appear in their input frames,
# guarding against accidental label leakage from upstream data pipelines.
_BLOCKED_COLS = frozenset({
    "motif_hit_count", "motif_source", "trigger_event_idx", "label_event_idx",
    "label_delay", "is_fallback_label", "fraud_source",
    "twin_role", "twin_label", "twin_pair_id", "template_id",
    "dynamic_fraud_state", "motif_chain_state", "motif_strength",
})
|
|
|
|
|
|
| |
| |
| |
|
|
class _SAGEEncoder(nn.Module):
    """Two-layer GraphSAGE encoder with a LayerNorm after each convolution.

    Layer 1 is followed by LayerNorm + ReLU; layer 2 by LayerNorm only, so
    the output embedding is not non-negativity constrained.
    """

    def __init__(self, in_dim: int, hidden_dim: int):
        super().__init__()
        self.conv1 = SAGEConv(in_dim, hidden_dim)
        self.conv2 = SAGEConv(hidden_dim, hidden_dim)
        self.norm1 = nn.LayerNorm(hidden_dim)
        self.norm2 = nn.LayerNorm(hidden_dim)

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor:
        """Return per-node embeddings of shape (num_nodes, hidden_dim)."""
        hidden = self.conv1(x, edge_index)
        hidden = F.relu(self.norm1(hidden))
        out = self.conv2(hidden, edge_index)
        return self.norm2(out)
|
|
|
|
| |
| |
| |
|
|
class StaticGNNWrapper(TemporalModel):
    """GraphSAGE with time-snapshot aggregation. No temporal memory.

    Events are binned into ``n_snapshots`` equal-count snapshots; a 2-layer
    GraphSAGE encoder is trained with an edge-level fraud objective on each
    snapshot, node embeddings are mean-pooled across snapshots, and a small
    MLP head is then fit on node-level labels.
    """

    def __init__(
        self,
        hidden_dim: int = 64,
        n_snapshots: int = 10,
        device: str = "cpu",
    ):
        self.hidden_dim = hidden_dim
        self.n_snapshots = n_snapshots
        self.device = torch.device(device)

        # Populated by fit().
        self._encoder: _SAGEEncoder | None = None
        self._node_clf: nn.Sequential | None = None
        self._norm_stats: dict | None = None  # edge-feature mean/std for reuse at predict time
        self._n_nodes: int = 0
        self._node_emb_agg: torch.Tensor | None = None  # snapshot-pooled train embeddings
        self._in_dim: int = 0

    @property
    def name(self) -> str:
        return "StaticGNN"

    @property
    def is_temporal(self) -> bool:
        return False

    def _build_snapshots(
        self, df: pd.DataFrame, ef_np: np.ndarray
    ) -> List[tuple]:
        """Bin time-sorted events into equal-count snapshots.

        Returns a list of ``(edge_index_t, edge_attr_t, lo, hi)`` tuples,
        where ``[lo, hi)`` is the row range (after time-sorting) covered by
        the bin. Storing the range here keeps the training loop's label
        slicing in lockstep with the binning arithmetic.

        Bins that would be empty (possible when ``len(df) < n_snapshots``)
        are skipped, so downstream losses never see zero edges — an empty
        BCE reduction would produce NaN and poison training.
        """
        df = df.sort_values("timestamp").reset_index(drop=True)
        n = len(df)
        bin_size = max(1, n // self.n_snapshots)

        snapshots = []
        for b in range(self.n_snapshots):
            lo = b * bin_size
            # Last bin absorbs the remainder; all bins are clamped to n.
            hi = n if b == self.n_snapshots - 1 else min(lo + bin_size, n)
            if lo >= hi:  # empty bin: nothing to train on
                continue
            sub_u = df["sender_id"].values[lo:hi].astype(np.int64)
            sub_v = df["receiver_id"].values[lo:hi].astype(np.int64)

            edge_index = torch.tensor(np.vstack([sub_u, sub_v]), dtype=torch.long)
            edge_attr = torch.tensor(ef_np[lo:hi], dtype=torch.float32)
            snapshots.append((edge_index, edge_attr, lo, hi))
        return snapshots

    def fit(self, df_train: pd.DataFrame, num_epochs: int = 3) -> None:
        """Train the SAGE encoder (edge objective), then the node head.

        Parameters
        ----------
        df_train : event table with at least ``timestamp``, ``sender_id``,
            ``receiver_id`` and ``is_fraud`` columns.
        num_epochs : passes over the full snapshot sequence for the encoder.
        """
        leaked = _BLOCKED_COLS & set(df_train.columns)
        assert not leaked, f"Oracle columns leaked into StaticGNN.fit(): {leaked}"
        df_train = df_train.sort_values("timestamp").reset_index(drop=True)

        ef_np = build_edge_features(df_train).astype(np.float32)
        edge_dim = ef_np.shape[1]
        self._in_dim = edge_dim

        # Standardize edge features; stats are frozen and reused at predict time.
        ea_mean = ef_np.mean(axis=0)
        ea_std = ef_np.std(axis=0) + 1e-6
        ef_np = (ef_np - ea_mean) / ea_std
        self._norm_stats = {"ea_mean": ea_mean, "ea_std": ea_std}

        all_ids = np.union1d(df_train["sender_id"].values, df_train["receiver_id"].values)
        n_nodes = int(all_ids.max()) + 1
        self._n_nodes = n_nodes

        device = self.device

        encoder = _SAGEEncoder(in_dim=edge_dim, hidden_dim=self.hidden_dim).to(device)
        self._encoder = encoder

        node_clf = nn.Sequential(
            nn.Linear(self.hidden_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 1),
        ).to(device)
        self._node_clf = node_clf

        snapshots = self._build_snapshots(df_train, ef_np)

        # Class-imbalance weighting for the edge-level BCE loss (capped at 10x).
        y_all = torch.tensor(df_train["is_fraud"].values, dtype=torch.float32)
        raw_pw = (y_all == 0).sum() / ((y_all == 1).sum() + 1e-6)
        pos_weight = torch.clamp(raw_pw, max=10.0).to(device)

        loss_fn_edge = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
        opt = torch.optim.Adam(
            list(encoder.parameters()) + list(node_clf.parameters()),
            lr=1e-3,
        )

        node_feat = self._build_node_feat(df_train, ef_np, n_nodes)
        x_full = torch.tensor(node_feat, dtype=torch.float32, device=device)

        for epoch in range(num_epochs):
            encoder.train()
            node_clf.train()
            total_loss = 0.0
            emb_accum = torch.zeros(n_nodes, self.hidden_dim, device=device)
            snap_cnt = torch.zeros(n_nodes, dtype=torch.float32, device=device)

            for edge_index, _edge_attr, lo, hi in snapshots:
                edge_index = edge_index.to(device)

                # Labels come straight from the stored row range, so they
                # always line up with this snapshot's edges.
                y_snap = y_all[lo:hi].to(device)

                h = encoder(x_full, edge_index)

                # Edge score = dot product of endpoint embeddings, clamped
                # for numerical stability of the logistic loss.
                h_src = h[edge_index[0]]
                h_dst = h[edge_index[1]]
                edge_logits = torch.clamp((h_src * h_dst).sum(dim=-1), -10, 10)
                loss = loss_fn_edge(edge_logits, y_snap)

                opt.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(encoder.parameters(), 1.0)
                opt.step()
                total_loss += loss.item()

                # Accumulate post-step embeddings for snapshot mean-pooling.
                with torch.no_grad():
                    emb_accum += h.detach()
                    snap_cnt += 1.0

            emb_pooled = emb_accum / snap_cnt.unsqueeze(1).clamp(min=1.0)
            self._node_emb_agg = emb_pooled.clone()

            print(f"[StaticGNN] Epoch {epoch + 1}/{num_epochs} Loss: {total_loss:.4f}")

        # Fit the node classifier head on the pooled (final-epoch) embeddings.
        self._train_node_clf(df_train)

    def _compute_prefix_embeddings(self, df_prefix: pd.DataFrame) -> torch.Tensor:
        """Compute node embeddings for a causal prefix graph.

        Uses the frozen train-time normalization stats; the whole prefix is
        encoded as one static graph (no snapshot pooling here).
        """
        device = self.device
        ns = self._norm_stats

        df_prefix = df_prefix.sort_values("timestamp").reset_index(drop=True)
        ef_np = build_edge_features(df_prefix).astype(np.float32)
        ef_np = (ef_np - ns["ea_mean"]) / ns["ea_std"]

        all_ids = np.union1d(df_prefix["sender_id"].values, df_prefix["receiver_id"].values)
        # Keep at least the train-time node count so train-era ids stay valid.
        n_nodes = max(int(all_ids.max()) + 1, self._n_nodes)
        node_feat = self._build_node_feat(df_prefix, ef_np, n_nodes)
        x = torch.tensor(node_feat, dtype=torch.float32, device=device)
        edge_index = torch.tensor(
            np.vstack([df_prefix["sender_id"].values, df_prefix["receiver_id"].values]),
            dtype=torch.long, device=device,
        )

        self._encoder.eval()
        with torch.no_grad():
            return self._encoder(x, edge_index)

    def _build_node_feat(
        self, df: pd.DataFrame, ef_np: np.ndarray, n_nodes: int
    ) -> np.ndarray:
        """Aggregate edge features to sender nodes (mean).

        Nodes that never appear as a sender keep an all-zero feature row.
        """
        feat = np.zeros((n_nodes, ef_np.shape[1]), dtype=np.float32)
        cnt = np.zeros(n_nodes, dtype=np.float32)
        sids = df["sender_id"].values.astype(np.int64)
        # np.add.at handles repeated sender ids correctly (unbuffered add).
        np.add.at(feat, sids, ef_np)
        np.add.at(cnt, sids, 1.0)
        cnt = np.maximum(cnt, 1.0)
        return feat / cnt[:, None]

    def _fit_clf(
        self, node_emb: torch.Tensor, y: torch.Tensor, num_epochs: int
    ) -> None:
        """Shared full-batch training loop for the node classifier head.

        Re-weights positives by the (capped) negative/positive ratio and
        leaves the head in eval mode afterwards.
        """
        pw = torch.clamp((y == 0).sum() / ((y == 1).sum() + 1e-6), max=10.0)
        loss_fn = nn.BCEWithLogitsLoss(pos_weight=pw)
        opt = torch.optim.Adam(self._node_clf.parameters(), lr=1e-3)
        self._node_clf.train()
        for _ in range(num_epochs):
            logits = self._node_clf(node_emb).squeeze(-1)
            loss = loss_fn(logits, y)
            opt.zero_grad()
            loss.backward()
            opt.step()
        self._node_clf.eval()

    def _train_node_clf(self, df_train: pd.DataFrame, num_epochs: int = 150) -> None:
        """Fine-tune node classifier on node-level fraud labels (training split)."""
        device = self.device
        emb = self._node_emb_agg
        all_nodes = sorted(df_train["sender_id"].unique())
        eval_t = torch.tensor(all_nodes, dtype=torch.long, device=device)

        # Node label = max over the node's sent events (any fraud -> fraud node).
        y_map = df_train.groupby("sender_id")["is_fraud"].max()
        y_np = np.array([y_map.get(n, 0) for n in all_nodes], dtype=np.float32)
        y = torch.tensor(y_np, device=device)

        node_emb = emb[eval_t].detach()
        self._fit_clf(node_emb, y, num_epochs)

    def predict(self, df_eval: pd.DataFrame, eval_nodes: List[int]) -> np.ndarray:
        """Return fraud probabilities for ``eval_nodes`` given ``df_eval`` events.

        The eval events are encoded as one static graph using the frozen
        train-time feature normalization.
        """
        assert self._encoder is not None, "Call fit() first."
        leaked = _BLOCKED_COLS & set(df_eval.columns)
        assert not leaked, f"Oracle columns leaked into StaticGNN.predict(): {leaked}"
        device = self.device

        ns = self._norm_stats

        df_eval = df_eval.sort_values("timestamp").reset_index(drop=True)
        ef_np = build_edge_features(df_eval).astype(np.float32)
        ef_np = (ef_np - ns["ea_mean"]) / ns["ea_std"]

        all_ids = np.union1d(df_eval["sender_id"].values, df_eval["receiver_id"].values)
        n_nodes = max(int(all_ids.max()) + 1, self._n_nodes)

        node_feat = self._build_node_feat(df_eval, ef_np, n_nodes)
        x = torch.tensor(node_feat, dtype=torch.float32, device=device)

        edge_index = torch.tensor(
            np.vstack([df_eval["sender_id"].values, df_eval["receiver_id"].values]),
            dtype=torch.long, device=device,
        )

        self._encoder.eval()
        with torch.no_grad():
            h = self._encoder(x, edge_index)

        eval_t = torch.tensor(eval_nodes, dtype=torch.long, device=device)
        node_emb = h[eval_t]

        with torch.no_grad():
            probs = torch.sigmoid(self._node_clf(node_emb).squeeze(-1)).cpu().numpy()
        return probs.astype(np.float32)

    def reset_memory(self) -> None:
        """No-op: StaticGNN has no temporal memory."""
        pass

    def train_node_classifier(
        self, eval_nodes: List[int], y_labels: np.ndarray, num_epochs: int = 150
    ) -> None:
        """Re-train node classifier with fresh labels (for horizon sweep)."""
        device = self.device
        eval_t = torch.tensor(eval_nodes, dtype=torch.long, device=device)
        node_emb = self._node_emb_agg[eval_t].detach()
        y = torch.tensor(y_labels, dtype=torch.float32, device=device)
        self._fit_clf(node_emb, y, num_epochs)

    def train_node_classifier_on_prefix(
        self,
        df_prefix: pd.DataFrame,
        eval_nodes: List[int],
        y_labels: np.ndarray,
        num_epochs: int = 150,
    ) -> None:
        """Train the node classifier on embeddings computed from a causal prefix."""
        device = self.device
        prefix_emb = self._compute_prefix_embeddings(df_prefix)
        eval_t = torch.tensor(eval_nodes, dtype=torch.long, device=device)
        node_emb = prefix_emb[eval_t].detach()
        y = torch.tensor(y_labels, dtype=torch.float32, device=device)
        self._fit_clf(node_emb, y, num_epochs)
|
|