"""
models/static_gnn.py
====================
Static GNN Baseline: GraphSAGE with Snapshot Batching
Architecture
------------
Events are binned into N time-snapshots (equal-count bins).
For each snapshot:
- Build a static homogeneous graph from the events in that bin
- Run 2-layer GraphSAGE to produce node embeddings
Per-node embeddings are then mean-pooled across all snapshots, and a node
classifier head is trained on the pooled embeddings.
This model has NO temporal memory between snapshots. It is the strongest
"static" baseline: it sees the full graph structure but cannot reason about
the ordering of events within or across snapshots.
Note: SAGEConv is used (from torch_geometric). A node with no edges in a
snapshot degrades gracefully: its neighbor-aggregation term is zero, so its
embedding reduces to the root-weight transform of its input features.
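
Usage sketch (illustrative; assumes an event frame with the sender_id,
receiver_id, timestamp, and is_fraud columns consumed below):

    model = StaticGNNWrapper(hidden_dim=64, n_snapshots=10, device="cpu")
    model.fit(df_train, num_epochs=3)
    probs = model.predict(df_eval, eval_nodes=sorted(df_eval["sender_id"].unique()))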
"""
from __future__ import annotations
from typing import List
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import SAGEConv
from models.base import TemporalModel
from src.graph.graph_builder import build_edge_features
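# Oracle / label-provenance columns. These encode ground-truth generation
# details and must never reach the model; fit() and predict() assert on them.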
_BLOCKED_COLS = frozenset({
"motif_hit_count", "motif_source", "trigger_event_idx", "label_event_idx",
"label_delay", "is_fallback_label", "fraud_source",
"twin_role", "twin_label", "twin_pair_id", "template_id",
"dynamic_fraud_state", "motif_chain_state", "motif_strength",
})
# ------------------------------------------------------------------ #
# Core GraphSAGE nn.Module #
# ------------------------------------------------------------------ #
class _SAGEEncoder(nn.Module):
def __init__(self, in_dim: int, hidden_dim: int):
super().__init__()
self.conv1 = SAGEConv(in_dim, hidden_dim)
self.conv2 = SAGEConv(hidden_dim, hidden_dim)
self.norm1 = nn.LayerNorm(hidden_dim)
self.norm2 = nn.LayerNorm(hidden_dim)
def forward(self, x: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor:
h = F.relu(self.norm1(self.conv1(x, edge_index)))
h = self.norm2(self.conv2(h, edge_index))
return h
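# Shape sketch (illustrative): given x of shape (n_nodes, in_dim) and
# edge_index of shape (2, n_edges),
#     _SAGEEncoder(in_dim=8, hidden_dim=64)(x, edge_index)
# returns node embeddings of shape (n_nodes, 64).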
# ------------------------------------------------------------------ #
# StaticGNNWrapper (TemporalModel interface) #
# ------------------------------------------------------------------ #
class StaticGNNWrapper(TemporalModel):
"""GraphSAGE with time-snapshot aggregation. No temporal memory."""
def __init__(
self,
hidden_dim: int = 64,
n_snapshots: int = 10,
device: str = "cpu",
):
self.hidden_dim = hidden_dim
self.n_snapshots = n_snapshots
self.device = torch.device(device)
self._encoder: _SAGEEncoder | None = None
self._node_clf: nn.Sequential | None = None
self._norm_stats: dict | None = None
self._n_nodes: int = 0
self._node_emb_agg: torch.Tensor | None = None # (n_nodes, hidden_dim)
self._in_dim: int = 0
@property
def name(self) -> str:
return "StaticGNN"
@property
def is_temporal(self) -> bool:
return False
# ------------------------------------------------------------------ #
def _build_snapshots(
self, df: pd.DataFrame, ef_np: np.ndarray
) -> List[tuple]:
"""
Returns list of (edge_index_t, edge_attr_t, src_nodes, dst_nodes)
for each snapshot bin.
"""
df = df.sort_values("timestamp").reset_index(drop=True)
n = len(df)
bin_size = max(1, n // self.n_snapshots)
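        # e.g. n=1005 events, n_snapshots=10 -> bin_size=100: bins 0..8 hold
        # 100 events each and the last bin absorbs the remaining 105.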
snapshots = []
for b in range(self.n_snapshots):
lo = b * bin_size
hi = lo + bin_size if b < self.n_snapshots - 1 else n
sub_u = df["sender_id"].values[lo:hi].astype(np.int64)
sub_v = df["receiver_id"].values[lo:hi].astype(np.int64)
sub_e = ef_np[lo:hi]
edge_index = torch.tensor(np.vstack([sub_u, sub_v]), dtype=torch.long)
edge_attr = torch.tensor(sub_e, dtype=torch.float32)
snapshots.append((edge_index, edge_attr, sub_u, sub_v))
return snapshots
# ------------------------------------------------------------------ #
def fit(self, df_train: pd.DataFrame, num_epochs: int = 3) -> None:
leaked = _BLOCKED_COLS & set(df_train.columns)
assert not leaked, f"Oracle columns leaked into StaticGNN.fit(): {leaked}"
df_train = df_train.sort_values("timestamp").reset_index(drop=True)
ef_np = build_edge_features(df_train).astype(np.float32)
edge_dim = ef_np.shape[1]
        self._in_dim = edge_dim  # node inputs: per-node mean of edge features over the full frame
ea_mean = ef_np.mean(axis=0)
ea_std = ef_np.std(axis=0) + 1e-6
ef_np = (ef_np - ea_mean) / ea_std
self._norm_stats = {"ea_mean": ea_mean, "ea_std": ea_std}
all_ids = np.union1d(df_train["sender_id"].values, df_train["receiver_id"].values)
n_nodes = int(all_ids.max()) + 1
self._n_nodes = n_nodes
device = self.device
        # Node input features: per-node mean of outgoing edge features over the
        # full training frame (shared across all snapshots)
encoder = _SAGEEncoder(in_dim=edge_dim, hidden_dim=self.hidden_dim).to(device)
self._encoder = encoder
node_clf = nn.Sequential(
nn.Linear(self.hidden_dim, 64),
nn.ReLU(),
nn.Linear(64, 1),
).to(device)
self._node_clf = node_clf
# Build snapshots
snapshots = self._build_snapshots(df_train, ef_np)
y_all = torch.tensor(df_train["is_fraud"].values, dtype=torch.float32)
raw_pw = (y_all == 0).sum() / ((y_all == 1).sum() + 1e-6)
pos_weight = torch.clamp(raw_pw, max=10.0).to(device)
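        # e.g. with 2% positive edges the raw ratio is ~49; the clamp caps the
        # positive-class weight at 10 to avoid swamping the loss.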
loss_fn_edge = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
opt = torch.optim.Adam(
list(encoder.parameters()) + list(node_clf.parameters()),
lr=1e-3,
)
# Build per-node input feature matrix: aggregate edge features to nodes
node_feat = self._build_node_feat(df_train, ef_np, n_nodes)
x_full = torch.tensor(node_feat, dtype=torch.float32, device=device)
for epoch in range(num_epochs):
encoder.train()
node_clf.train()
total_loss = 0.0
emb_accum = torch.zeros(n_nodes, self.hidden_dim, device=device)
snap_cnt = torch.zeros(n_nodes, dtype=torch.float32, device=device)
for snap_idx, (edge_index, edge_attr, src_np, _) in enumerate(snapshots):
edge_index = edge_index.to(device)
edge_attr = edge_attr.to(device)
                # Recover this snapshot's slice of the sorted df
                # (mirrors the binning in _build_snapshots)
n = len(df_train)
bin_size = max(1, n // self.n_snapshots)
lo = snap_idx * bin_size
hi = lo + bin_size if snap_idx < self.n_snapshots - 1 else n
y_snap = y_all[lo:hi].to(device)
h = encoder(x_full, edge_index) # (n_nodes, hidden_dim)
# Edge-level fraud loss on this snapshot
src_t = edge_index[0]
dst_t = edge_index[1]
h_src = h[src_t]
h_dst = h[dst_t]
edge_logits = (h_src * h_dst).sum(dim=-1) # dot-product score
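                # Clamping keeps BCE-with-logits stable early in training, when
                # dot products of unnormalized embeddings can be large.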
edge_logits = torch.clamp(edge_logits, -10, 10)
loss = loss_fn_edge(edge_logits, y_snap)
opt.zero_grad()
loss.backward()
torch.nn.utils.clip_grad_norm_(encoder.parameters(), 1.0)
opt.step()
total_loss += loss.item()
# Accumulate node embeddings across snapshots (detached)
with torch.no_grad():
emb_accum += h.detach()
snap_cnt += 1.0
# Pooled node embedding
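            # (snap_cnt is incremented uniformly, so this divides every node's
            # accumulated embedding by the same snapshot count)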
emb_pooled = emb_accum / snap_cnt.unsqueeze(1).clamp(min=1.0)
self._node_emb_agg = emb_pooled.clone()
print(f"[StaticGNN] Epoch {epoch + 1}/{num_epochs} Loss: {total_loss:.4f}")
        # Encoder is not updated past this point; train the node classifier
        # head on the pooled (detached) embeddings.
self._train_node_clf(df_train)
# ------------------------------------------------------------------ #
def _compute_prefix_embeddings(self, df_prefix: pd.DataFrame) -> torch.Tensor:
"""Compute node embeddings for a causal prefix graph."""
device = self.device
ns = self._norm_stats
df_prefix = df_prefix.sort_values("timestamp").reset_index(drop=True)
ef_np = build_edge_features(df_prefix).astype(np.float32)
ef_np = (ef_np - ns["ea_mean"]) / ns["ea_std"]
all_ids = np.union1d(df_prefix["sender_id"].values, df_prefix["receiver_id"].values)
n_nodes = max(int(all_ids.max()) + 1, self._n_nodes)
node_feat = self._build_node_feat(df_prefix, ef_np, n_nodes)
x = torch.tensor(node_feat, dtype=torch.float32, device=device)
edge_index = torch.tensor(
np.vstack([df_prefix["sender_id"].values, df_prefix["receiver_id"].values]),
dtype=torch.long, device=device,
)
self._encoder.eval()
with torch.no_grad():
return self._encoder(x, edge_index)
# ------------------------------------------------------------------ #
def _build_node_feat(
self, df: pd.DataFrame, ef_np: np.ndarray, n_nodes: int
) -> np.ndarray:
"""Aggregate edge features to sender nodes (mean)."""
feat = np.zeros((n_nodes, ef_np.shape[1]), dtype=np.float32)
cnt = np.zeros(n_nodes, dtype=np.float32)
sids = df["sender_id"].values.astype(np.int64)
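        # np.add.at is an unbuffered scatter-add: repeated sender ids accumulate
        # correctly, unlike the fancy-indexing form feat[sids] += ef_np.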
np.add.at(feat, sids, ef_np)
np.add.at(cnt, sids, 1.0)
cnt = np.maximum(cnt, 1.0)
return feat / cnt[:, None]
def _train_node_clf(self, df_train: pd.DataFrame, num_epochs: int = 150) -> None:
"""Fine-tune node classifier on node-level fraud labels (training split)."""
device = self.device
emb = self._node_emb_agg # (n_nodes, hidden_dim)
all_nodes = sorted(df_train["sender_id"].unique())
eval_t = torch.tensor(all_nodes, dtype=torch.long, device=device)
# Build node-level labels: any fraud in the training window?
y_map = df_train.groupby("sender_id")["is_fraud"].max()
y_np = np.array([y_map.get(n, 0) for n in all_nodes], dtype=np.float32)
y = torch.tensor(y_np, device=device)
node_emb = emb[eval_t].detach()
pw = torch.clamp((y == 0).sum() / ((y == 1).sum() + 1e-6), max=10.0)
loss_fn = nn.BCEWithLogitsLoss(pos_weight=pw)
opt = torch.optim.Adam(self._node_clf.parameters(), lr=1e-3)
self._node_clf.train()
for _ in range(num_epochs):
logits = self._node_clf(node_emb).squeeze(-1)
loss = loss_fn(logits, y)
opt.zero_grad()
loss.backward()
opt.step()
self._node_clf.eval()
# ------------------------------------------------------------------ #
def predict(self, df_eval: pd.DataFrame, eval_nodes: List[int]) -> np.ndarray:
assert self._encoder is not None, "Call fit() first."
leaked = _BLOCKED_COLS & set(df_eval.columns)
assert not leaked, f"Oracle columns leaked into StaticGNN.predict(): {leaked}"
device = self.device
ns = self._norm_stats
# Build node embeddings from eval graph (no memory — static)
df_eval = df_eval.sort_values("timestamp").reset_index(drop=True)
ef_np = build_edge_features(df_eval).astype(np.float32)
ef_np = (ef_np - ns["ea_mean"]) / ns["ea_std"]
all_ids = np.union1d(df_eval["sender_id"].values, df_eval["receiver_id"].values)
n_nodes = max(int(all_ids.max()) + 1, self._n_nodes)
node_feat = self._build_node_feat(df_eval, ef_np, n_nodes)
x = torch.tensor(node_feat, dtype=torch.float32, device=device)
edge_index = torch.tensor(
np.vstack([df_eval["sender_id"].values, df_eval["receiver_id"].values]),
dtype=torch.long, device=device,
)
self._encoder.eval()
with torch.no_grad():
h = self._encoder(x, edge_index) # (n_nodes, hidden_dim)
eval_t = torch.tensor(eval_nodes, dtype=torch.long, device=device)
node_emb = h[eval_t]
with torch.no_grad():
probs = torch.sigmoid(self._node_clf(node_emb).squeeze(-1)).cpu().numpy()
return probs.astype(np.float32)
# ------------------------------------------------------------------ #
def reset_memory(self) -> None:
"""No-op: StaticGNN has no temporal memory."""
pass
# ------------------------------------------------------------------ #
def train_node_classifier(
self, eval_nodes: List[int], y_labels: np.ndarray, num_epochs: int = 150
) -> None:
"""Re-train node classifier with fresh labels (for horizon sweep)."""
device = self.device
eval_t = torch.tensor(eval_nodes, dtype=torch.long, device=device)
node_emb = self._node_emb_agg[eval_t].detach()
y = torch.tensor(y_labels, dtype=torch.float32, device=device)
pw = torch.clamp((y == 0).sum() / ((y == 1).sum() + 1e-6), max=10.0)
loss_fn = nn.BCEWithLogitsLoss(pos_weight=pw)
opt = torch.optim.Adam(self._node_clf.parameters(), lr=1e-3)
self._node_clf.train()
for _ in range(num_epochs):
logits = self._node_clf(node_emb).squeeze(-1)
loss = loss_fn(logits, y)
opt.zero_grad()
loss.backward()
opt.step()
self._node_clf.eval()
def train_node_classifier_on_prefix(
self,
df_prefix: pd.DataFrame,
eval_nodes: List[int],
y_labels: np.ndarray,
num_epochs: int = 150,
) -> None:
"""Train the node classifier on embeddings computed from a causal prefix."""
device = self.device
prefix_emb = self._compute_prefix_embeddings(df_prefix)
eval_t = torch.tensor(eval_nodes, dtype=torch.long, device=device)
node_emb = prefix_emb[eval_t].detach()
y = torch.tensor(y_labels, dtype=torch.float32, device=device)
pw = torch.clamp((y == 0).sum() / ((y == 1).sum() + 1e-6), max=10.0)
loss_fn = nn.BCEWithLogitsLoss(pos_weight=pw)
opt = torch.optim.Adam(self._node_clf.parameters(), lr=1e-3)
self._node_clf.train()
for _ in range(num_epochs):
logits = self._node_clf(node_emb).squeeze(-1)
loss = loss_fn(logits, y)
opt.zero_grad()
loss.backward()
opt.step()
self._node_clf.eval()
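# ------------------------------------------------------------------ #
# Smoke test (illustrative sketch)                                   #
# ------------------------------------------------------------------ #
# Minimal end-to-end run on synthetic events. Illustrative only: it assumes
# build_edge_features() accepts a frame with just the columns below (the
# "amount" column is a guess at the feature builder's schema); adjust to the
# real schema if it differs.
if __name__ == "__main__":
    rng = np.random.default_rng(0)
    n_events, n_accounts = 2_000, 50
    df = pd.DataFrame({
        "sender_id": rng.integers(0, n_accounts, n_events),
        "receiver_id": rng.integers(0, n_accounts, n_events),
        "timestamp": np.sort(rng.uniform(0.0, 1e6, n_events)),
        "amount": rng.lognormal(3.0, 1.0, n_events),  # hypothetical feature column
        "is_fraud": (rng.random(n_events) < 0.05).astype(np.int64),
    })
    model = StaticGNNWrapper(hidden_dim=32, n_snapshots=5)
    model.fit(df, num_epochs=1)
    probs = model.predict(df, eval_nodes=sorted(df["sender_id"].unique()))
    print(f"[smoke] {len(probs)} node scores, mean prob {probs.mean():.3f}")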