""" models/dyrep.py =============== DyRep: Learning Representations over Dynamic Graphs Trivedi et al., NeurIPS 2019 Architecture ------------ DyRep models the evolution of node representations via two interleaved processes: 1. Communication (association): A new edge (u,v,t) triggers mutual updates h_u ← GRU(h_u, msg(h_u, h_v, Δt_u, e)) h_v ← GRU(h_v, msg(h_v, h_u, Δt_v, e)) 2. No explicit "propagation" process is used here; the GRU-based update already serves the equivalent role in our streaming setting. Message is conditioned on: - Current embeddings of both endpoints (h_u, h_v) - Time since last interaction for each node (Δt_u, Δt_v) → sinusoidal encoding - Edge features Intensity function λ(u,v,t) is learnt via a bilinear form and used as a proxy training signal (event likelihood maximisation), augmented by a BCE edge-fraud loss. This follows the original paper's framing closely while being adapted to the event-stream training loop of the upi-sim benchmark. """ from __future__ import annotations from typing import List import numpy as np import pandas as pd import torch import torch.nn as nn from models.base import TemporalModel from src.graph.graph_builder import build_edge_features from src.tgn.time_encoding import TimeEncoding # ------------------------------------------------------------------ # # Core DyRep nn.Module # # ------------------------------------------------------------------ # class _DyRepModule(nn.Module): def __init__(self, memory_dim: int, edge_dim: int, time_dim: int): super().__init__() self.memory_dim = memory_dim self.time_enc = TimeEncoding(time_dim) # Message function: h_u, h_v, φ(Δt), edge → message self.msg_fn = nn.Sequential( nn.Linear(2 * memory_dim + 2 * time_dim + edge_dim, memory_dim), nn.Tanh(), nn.Linear(memory_dim, memory_dim), ) # GRU cell for memory update self.gru = nn.GRUCell(memory_dim, memory_dim) # Intensity function: bilinear score between endpoint embeddings # λ(u,v,t) = sigmoid(h_u^T W h_v) self.W_intensity = nn.Bilinear(memory_dim, memory_dim, 1) # Node fraud classifier self.classifier = nn.Sequential( nn.Linear(memory_dim, 64), nn.ReLU(), nn.Linear(64, 1), ) def compute_message( self, h_src: torch.Tensor, # (B, mem_dim) h_dst: torch.Tensor, # (B, mem_dim) dt: torch.Tensor, # (B,) — time since last event for src edge_feat: torch.Tensor, # (B, edge_dim) ) -> torch.Tensor: phi_dt = self.time_enc(dt) # (B, 2*time_dim) inp = torch.cat([h_src, h_dst, phi_dt, edge_feat], dim=-1) return self.msg_fn(inp) def intensity(self, h_u: torch.Tensor, h_v: torch.Tensor) -> torch.Tensor: """Hawkes-like point-process intensity.""" return torch.sigmoid(self.W_intensity(h_u, h_v).squeeze(-1)) def classify(self, h: torch.Tensor) -> torch.Tensor: return self.classifier(h).squeeze(-1) # ------------------------------------------------------------------ # # DyRepWrapper (TemporalModel interface) # # ------------------------------------------------------------------ # class DyRepWrapper(TemporalModel): """DyRep intensity-based temporal model.""" def __init__( self, memory_dim: int = 64, time_dim: int = 8, device: str = "cpu", ): self.memory_dim = memory_dim self.time_dim = time_dim self.device = torch.device(device) self._module: _DyRepModule | None = None self._memory: torch.Tensor | None = None # (n_nodes, mem_dim) self._last_t: torch.Tensor | None = None # (n_nodes,) last event time self._norm_stats: dict | None = None self._n_nodes: int = 0 @property def name(self) -> str: return "DyRep" # ------------------------------------------------------------------ # def fit(self, df_train: pd.DataFrame, num_epochs: int = 3) -> None: df_train = df_train.sort_values("timestamp").reset_index(drop=True) ef_np = build_edge_features(df_train).astype(np.float32) edge_dim = ef_np.shape[1] ea_mean = ef_np.mean(axis=0) ea_std = ef_np.std(axis=0) + 1e-6 ef_np = (ef_np - ea_mean) / ea_std t_vals = df_train["timestamp"].values.astype(np.float32) t_min, t_max = t_vals.min(), t_vals.max() t_norm = (t_vals - t_min) / (t_max - t_min + 1e-6) * 5.0 self._norm_stats = { "ea_mean": ea_mean, "ea_std": ea_std, "t_min": t_min, "t_max": t_max, } all_ids = np.union1d(df_train["sender_id"].values, df_train["receiver_id"].values) n_nodes = int(all_ids.max()) + 1 self._n_nodes = n_nodes module = _DyRepModule( memory_dim=self.memory_dim, edge_dim=edge_dim, time_dim=self.time_dim, ).to(self.device) self._module = module memory = torch.zeros(n_nodes, self.memory_dim, device=self.device) last_t = torch.zeros(n_nodes, device=self.device) self._memory = memory self._last_t = last_t u_ids = torch.tensor(df_train["sender_id"].values, dtype=torch.long) v_ids = torch.tensor(df_train["receiver_id"].values, dtype=torch.long) ef_all = torch.tensor(ef_np, dtype=torch.float32) t_all = torch.tensor(t_norm, dtype=torch.float32) y_all = torch.tensor(df_train["is_fraud"].values, dtype=torch.float32) raw_pw = (y_all == 0).sum() / ((y_all == 1).sum() + 1e-6) pos_weight = torch.clamp(raw_pw, max=10.0).to(self.device) bce_fn = nn.BCEWithLogitsLoss(pos_weight=pos_weight) # Edge-level classifier for proxy training edge_clf = nn.Sequential( nn.Linear(self.memory_dim * 2 + edge_dim, 64), nn.ReLU(), nn.Linear(64, 1), ).to(self.device) self._edge_clf = edge_clf opt = torch.optim.Adam( list(module.parameters()) + list(edge_clf.parameters()), lr=1e-3, ) batch_size = 512 N = len(df_train) for epoch in range(num_epochs): memory.zero_() last_t.zero_() total_loss = 0.0 for i in range(0, N, batch_size): j = min(i + batch_size, N) u_b = u_ids[i:j].to(self.device) v_b = v_ids[i:j].to(self.device) t_b = t_all[i:j].to(self.device) ef_b = ef_all[i:j].to(self.device) y_b = y_all[i:j].to(self.device) h_u = memory[u_b].clone() h_v = memory[v_b].clone() dt_u = (t_b - last_t[u_b]).clamp(min=0.0) dt_v = (t_b - last_t[v_b]).clamp(min=0.0) # DyRep: both nodes update using each other's context msg_u = module.compute_message(h_u, h_v.detach(), dt_u, ef_b) msg_v = module.compute_message(h_v, h_u.detach(), dt_v, ef_b) h_u_new = module.gru(msg_u, h_u.detach()) h_v_new = module.gru(msg_v, h_v.detach()) # Scatter memory updates (unique-node safe) both_ids = torch.cat([u_b, v_b]) both_h = torch.cat([h_u_new, h_v_new], dim=0) unique_ids, inv = torch.unique(both_ids, return_inverse=True) agg_h = torch.zeros(len(unique_ids), self.memory_dim, device=self.device) agg_h.index_add_(0, inv, both_h.detach()) cnt = torch.bincount(inv).unsqueeze(1).float() memory[unique_ids] = agg_h / cnt last_t[u_b] = t_b last_t[v_b] = t_b # --- Loss -------------------------------------------------------- # 1. Intensity (event likelihood) — regression to 1 for observed edges lam = module.intensity(h_u_new, h_v_new) intensity_loss = -torch.log(lam + 1e-8).mean() # 2. Edge-level fraud classification ef_concat = torch.cat([h_u_new, h_v_new, ef_b], dim=-1) logits = edge_clf(ef_concat).squeeze(-1) logits = torch.clamp(logits, -10, 10) fraud_loss = bce_fn(logits, y_b) loss = fraud_loss + 0.1 * intensity_loss opt.zero_grad() loss.backward() torch.nn.utils.clip_grad_norm_(module.parameters(), 1.0) opt.step() total_loss += loss.item() print(f"[DyRep] Epoch {epoch + 1}/{num_epochs} Loss: {total_loss:.4f}") # Node classifier head self._node_clf = nn.Sequential( nn.Linear(self.memory_dim, 64), nn.ReLU(), nn.Linear(64, 1), ).to(self.device) # ------------------------------------------------------------------ # def predict(self, df_eval: pd.DataFrame, eval_nodes: List[int]) -> np.ndarray: assert self._module is not None, "Call fit() first." df_eval = df_eval.sort_values("timestamp").reset_index(drop=True) device = self.device module = self._module memory = self._memory last_t = self._last_t ns = self._norm_stats ef_np = build_edge_features(df_eval).astype(np.float32) ef_np = (ef_np - ns["ea_mean"]) / ns["ea_std"] t_vals = df_eval["timestamp"].values.astype(np.float32) t_norm = (t_vals - ns["t_min"]) / (ns["t_max"] - ns["t_min"] + 1e-6) * 5.0 u_ids = torch.tensor(df_eval["sender_id"].values, dtype=torch.long) v_ids = torch.tensor(df_eval["receiver_id"].values, dtype=torch.long) ef_t = torch.tensor(ef_np, dtype=torch.float32) t_t = torch.tensor(t_norm, dtype=torch.float32) module.eval() batch_size = 512 with torch.no_grad(): for i in range(0, len(df_eval), batch_size): j = min(i + batch_size, len(df_eval)) u_b = u_ids[i:j].to(device) v_b = v_ids[i:j].to(device) t_b = t_t[i:j].to(device) ef_b = ef_t[i:j].to(device) h_u = memory[u_b].clone() h_v = memory[v_b].clone() dt_u = (t_b - last_t[u_b]).clamp(min=0.0) msg_u = module.compute_message(h_u, h_v, dt_u, ef_b) h_u_new = module.gru(msg_u, h_u) msg_v = module.compute_message(h_v, h_u, (t_b - last_t[v_b]).clamp(min=0.0), ef_b) h_v_new = module.gru(msg_v, h_v) both = torch.cat([u_b, v_b]) both_h = torch.cat([h_u_new, h_v_new], dim=0) unique_ids, inv = torch.unique(both, return_inverse=True) agg_h = torch.zeros(len(unique_ids), self.memory_dim, device=device) agg_h.index_add_(0, inv, both_h) cnt = torch.bincount(inv).unsqueeze(1).float() memory[unique_ids] = agg_h / cnt last_t[u_b] = t_b last_t[v_b] = t_b eval_t = torch.tensor( [min(n, self._n_nodes - 1) for n in eval_nodes], dtype=torch.long, device=device, ) node_emb = memory[eval_t] if not hasattr(self, "_node_clf") or self._node_clf is None: self._node_clf = nn.Sequential( nn.Linear(self.memory_dim, 64), nn.ReLU(), nn.Linear(64, 1) ).to(device) with torch.no_grad(): probs = torch.sigmoid(self._node_clf(node_emb).squeeze(-1)).cpu().numpy() return probs.astype(np.float32) def extract_prefix_embeddings( self, df_eval: pd.DataFrame, examples: pd.DataFrame, ) -> np.ndarray: assert self._module is not None, "Call fit() first." if examples.empty: return np.zeros((0, self.memory_dim), dtype=np.float32) df_eval = df_eval.sort_values("timestamp").reset_index(drop=True).copy() if "local_event_idx" not in df_eval.columns: df_eval["local_event_idx"] = df_eval.groupby("sender_id").cumcount().astype(np.int32) capture_map: dict[tuple[int, int], list[int]] = {} for ex_idx, row in enumerate(examples.itertuples(index=False)): key = (int(row.sender_id), int(row.eval_local_event_idx)) capture_map.setdefault(key, []).append(ex_idx) max_seen_id = int(max(df_eval["sender_id"].max(), df_eval["receiver_id"].max())) + 1 memory = torch.zeros(max(self._n_nodes, max_seen_id), self.memory_dim, device=self.device) last_t = torch.zeros(max(self._n_nodes, max_seen_id), device=self.device) ns = self._norm_stats module = self._module ef_np = build_edge_features(df_eval).astype(np.float32) ef_np = (ef_np - ns["ea_mean"]) / ns["ea_std"] t_vals = df_eval["timestamp"].to_numpy(dtype=np.float32) t_norm = (t_vals - ns["t_min"]) / (ns["t_max"] - ns["t_min"] + 1e-6) * 5.0 out = np.zeros((len(examples), self.memory_dim), dtype=np.float32) module.eval() with torch.no_grad(): for idx, row in enumerate(df_eval.itertuples(index=False)): u = torch.tensor([int(row.sender_id)], dtype=torch.long, device=self.device) v = torch.tensor([int(row.receiver_id)], dtype=torch.long, device=self.device) t = torch.tensor([t_norm[idx]], dtype=torch.float32, device=self.device) ef = torch.tensor(ef_np[idx:idx + 1], dtype=torch.float32, device=self.device) h_u = memory[u].clone() h_v = memory[v].clone() dt_u = (t - last_t[u]).clamp(min=0.0) dt_v = (t - last_t[v]).clamp(min=0.0) msg_u = module.compute_message(h_u, h_v, dt_u, ef) msg_v = module.compute_message(h_v, h_u, dt_v, ef) h_u_new = module.gru(msg_u, h_u) h_v_new = module.gru(msg_v, h_v) both_ids = torch.cat([u, v]) both_h = torch.cat([h_u_new, h_v_new], dim=0) unique_ids, inv = torch.unique(both_ids, return_inverse=True) agg_h = torch.zeros(len(unique_ids), self.memory_dim, device=self.device) agg_h.index_add_(0, inv, both_h) cnt = torch.bincount(inv).unsqueeze(1).float() memory[unique_ids] = agg_h / cnt last_t[u] = t last_t[v] = t key = (int(row.sender_id), int(row.local_event_idx)) if key in capture_map: emb = memory[int(row.sender_id)].detach().cpu().numpy().astype(np.float32) for ex_idx in capture_map[key]: out[ex_idx] = emb return out # ------------------------------------------------------------------ # def reset_memory(self) -> None: if self._memory is not None: self._memory.zero_() self._last_t.zero_() # ------------------------------------------------------------------ # def train_node_classifier( self, eval_nodes: List[int], y_labels: np.ndarray, num_epochs: int = 150 ) -> None: device = self.device eval_t = torch.tensor(eval_nodes, dtype=torch.long, device=device) node_emb = self._memory[eval_t].detach() y = torch.tensor(y_labels, dtype=torch.float32, device=device) pw = torch.clamp((y == 0).sum() / ((y == 1).sum() + 1e-6), max=10.0) loss_fn = nn.BCEWithLogitsLoss(pos_weight=pw) opt = torch.optim.Adam(self._node_clf.parameters(), lr=1e-3) self._node_clf.train() for _ in range(num_epochs): logits = self._node_clf(node_emb).squeeze(-1) loss = loss_fn(logits, y) opt.zero_grad() loss.backward() opt.step() self._node_clf.eval()