| """ |
| models/jodie.py |
| =============== |
| JODIE: Predicting Dynamic Embedding Trajectory in Temporal Interaction Networks |
| Kumar et al., KDD 2019 |
| |
| Architecture |
| ------------ |
| JODIE maintains dual dynamic embeddings — one per node role: |
| - User (sender) embedding: h_u ← updated on each outgoing event |
| - Item (receiver) embedding: h_v ← updated on each incoming event |
| |
| Key ideas: |
| 1. Time projection: Before each update, project the existing embedding forward |
| in time using a learned linear transformation conditioned on Δt: |
| ĥ_u(t) = clamp(1 + W_u · Δt_emb, -2, 2) ⊙ h_u [element-wise time scaling] |
| where Δt_emb = ReLU(linear(clamp(Δt, 0, 5))) is a learnable time embedding. |
| |
| 2. RNN update: After projection, the RNN ingests the *other node's projected |
| embedding* concatenated with edge features: |
| h_u ← RNN( cat(ĥ_v, edge_feat), ĥ_u ) |
| h_v ← RNN( cat(ĥ_u, edge_feat), ĥ_v ) |
| |
| 3. Node classifier: operates on the latest stored node memory (no time projection is applied) at evaluation time. |
| |
| This is a faithful re-implementation of the JODIE equations from the KDD'19 paper, |
| adapted to the event-stream training loop of the upi-sim benchmark. |
| """ |
|
|
| from __future__ import annotations |
|
|
| from typing import List |
|
|
| import numpy as np |
| import pandas as pd |
| import torch |
| import torch.nn as nn |
|
|
| from models.base import TemporalModel |
| from src.graph.graph_builder import build_edge_features |
|
|
|
|
| |
| |
| |
|
|
class _JODIEModule(nn.Module):
    """Trainable components of JODIE: time projection, dual GRU cells, heads.

    The module holds no per-node state. Callers pass embeddings in and store
    the returned tensors themselves.

    Args:
        memory_dim: Width of every node embedding.
        edge_dim: Width of the per-event edge feature vector.
        time_emb_dim: Width of the learned time-gap embedding.
    """

    def __init__(self, memory_dim: int, edge_dim: int, time_emb_dim: int = 16):
        super().__init__()
        self.memory_dim = memory_dim
        rnn_input_dim = memory_dim + edge_dim

        # Scalar time gap -> time embedding.
        self.time_emb = nn.Linear(1, time_emb_dim)

        # Time embedding -> per-dimension scale offset, one map per role.
        self.W_proj_u = nn.Linear(time_emb_dim, memory_dim, bias=False)
        self.W_proj_v = nn.Linear(time_emb_dim, memory_dim, bias=False)

        # One recurrent cell per role; the input is [other embedding, edge features].
        self.rnn_u = nn.GRUCell(rnn_input_dim, memory_dim)
        self.rnn_v = nn.GRUCell(rnn_input_dim, memory_dim)

        # Normalisation applied to each cell's output.
        self.norm_u = nn.LayerNorm(memory_dim)
        self.norm_v = nn.LayerNorm(memory_dim)

        # Small MLP producing one logit per embedding.
        self.classifier = nn.Sequential(
            nn.Linear(memory_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 1),
        )

    def project(
        self,
        h: torch.Tensor,
        dt: torch.Tensor,
        W_proj: nn.Linear,
    ) -> torch.Tensor:
        """Scale ``h`` element-wise by a factor computed from the time gap.

        The gap is clamped to [0, 5] and the resulting factor
        ``1 + W_proj(relu(time_emb(dt)))`` is clamped to [-2, 2], so large
        gaps cannot blow the embedding up.

        Args:
            h: Embeddings to project.
            dt: Time gaps, one per embedding.
            W_proj: Projection layer to use (``W_proj_u`` or ``W_proj_v``).

        Returns:
            The projected embeddings, same shape as ``h``.
        """
        gap = torch.clamp(dt, min=0.0, max=5.0)
        gap_embedding = self.time_emb(gap[..., None]).relu()
        factor = torch.clamp(W_proj(gap_embedding) + 1.0, min=-2.0, max=2.0)
        return factor * h

    def update(
        self,
        h_self: torch.Tensor,
        h_other: torch.Tensor,
        edge_feat: torch.Tensor,
        rnn: nn.GRUCell,
        norm: nn.LayerNorm,
    ) -> torch.Tensor:
        """Run one recurrent step and layer-normalise the result.

        Args:
            h_self: Hidden state of the node being updated.
            h_other: Embedding of the counterpart node.
            edge_feat: Edge features of the interaction.
            rnn: Recurrent cell to apply.
            norm: Normalisation layer applied to the cell output.

        Returns:
            The normalised new hidden state.
        """
        cell_input = torch.cat((h_other, edge_feat), dim=-1)
        return norm(rnn(cell_input, h_self))

    def classify(self, h: torch.Tensor) -> torch.Tensor:
        """Return one logit per embedding, with the trailing axis removed."""
        return torch.squeeze(self.classifier(h), dim=-1)
|
|
|
|
|
|
| |
| |
| |
|
|
class JODIEWrapper(TemporalModel):
    """JODIE dual-RNN temporal model with time-projection embeddings.

    The wrapper owns a per-node memory table and a per-node "last seen"
    timestamp, and streams events through a ``_JODIEModule`` in
    chronological mini-batches.

    Typical usage is ``fit`` -> ``predict`` (which also advances the memory
    over the evaluation events) -> ``train_node_classifier`` -> ``predict``.
    The node classifier created by ``fit`` is untrained until
    ``train_node_classifier`` is called.
    """

    # Number of consecutive events processed per step in fit() and predict().
    _BATCH_SIZE = 512

    def __init__(
        self,
        memory_dim: int = 64,
        time_emb_dim: int = 16,
        device: str = "cpu",
    ):
        """Store hyper-parameters; no tensors are allocated until ``fit``.

        Args:
            memory_dim: Width of each node embedding.
            time_emb_dim: Width of the learned time-gap embedding.
            device: Torch device string used for all tensors and modules.
        """
        self.memory_dim = memory_dim
        self.time_emb_dim = time_emb_dim
        self.device = torch.device(device)

        self._module: _JODIEModule | None = None
        self._memory: torch.Tensor | None = None
        self._last_t: torch.Tensor | None = None
        self._norm_stats: dict | None = None
        self._n_nodes: int = 0
        self._edge_dim: int = 0
        # Declared here so later code can test for None instead of hasattr().
        self._edge_clf: nn.Sequential | None = None
        self._node_clf: nn.Sequential | None = None

    @property
    def name(self) -> str:
        """Short identifier of this model."""
        return "JODIE"

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------

    def _make_node_clf(self) -> nn.Sequential:
        """Build a freshly initialised node-level classification head."""
        return nn.Sequential(
            nn.Linear(self.memory_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 1),
        ).to(self.device)

    def _prepare_eval_arrays(self, df: pd.DataFrame) -> tuple[np.ndarray, np.ndarray]:
        """Return (edge features, timestamps) normalised with the fit() statistics."""
        ns = self._norm_stats
        ef_np = build_edge_features(df).astype(np.float32)
        ef_np = (ef_np - ns["ea_mean"]) / ns["ea_std"]
        # Subtract in float64 first: float32 cannot resolve small gaps between
        # large raw timestamps.
        t_vals = df["timestamp"].to_numpy(dtype=np.float64)
        t_norm = (t_vals - ns["t_min"]) / (ns["t_max"] - ns["t_min"] + 1e-6)
        return ef_np, t_norm.astype(np.float32)

    def _ensure_capacity(self, n_required: int) -> None:
        """Grow the persistent memory with zero rows so ids < n_required are valid."""
        current = self._memory.shape[0]
        if n_required <= current:
            return
        extra = n_required - current
        self._memory = torch.cat(
            [self._memory, torch.zeros(extra, self.memory_dim, device=self.device)],
            dim=0,
        )
        self._last_t = torch.cat(
            [self._last_t, torch.zeros(extra, device=self.device)],
            dim=0,
        )
        self._n_nodes = n_required

    def _lookup_embeddings(self, node_ids: List[int]) -> torch.Tensor:
        """Return stored embeddings for node_ids; ids outside the table get zeros."""
        ids = torch.as_tensor(
            np.asarray(node_ids, dtype=np.int64).reshape(-1), device=self.device
        )
        known = (ids >= 0) & (ids < self._memory.shape[0])
        emb = torch.zeros(len(ids), self.memory_dim, device=self.device)
        emb[known] = self._memory[ids[known]]
        return emb

    def _write_back(
        self,
        memory: torch.Tensor,
        last_t: torch.Tensor,
        u_b: torch.Tensor,
        v_b: torch.Tensor,
        t_b: torch.Tensor,
        h_u_new: torch.Tensor,
        h_v_new: torch.Tensor,
    ) -> None:
        """Store a batch's new states: mean state and latest time per node."""
        both_ids = torch.cat([u_b, v_b])
        both_h = torch.nan_to_num(
            torch.cat([h_u_new, h_v_new], dim=0), nan=0.0
        ).detach()
        both_t = torch.cat([t_b, t_b])

        unique_ids, inv = torch.unique(both_ids, return_inverse=True)

        # A node may occur several times in one batch; store the mean of its
        # new states.
        agg_h = torch.zeros(len(unique_ids), self.memory_dim, device=self.device)
        agg_h.index_add_(0, inv, both_h)
        cnt = torch.bincount(inv).unsqueeze(1).float()
        memory[unique_ids] = agg_h / cnt

        # Reduce to the latest timestamp per node. Plain indexed assignment
        # with duplicate ids could leave an earlier event's time in place.
        latest = torch.full(
            (len(unique_ids),), float("-inf"), dtype=both_t.dtype, device=self.device
        )
        latest.scatter_reduce_(0, inv, both_t, reduce="amax", include_self=True)
        last_t[unique_ids] = latest

    def _advance(
        self,
        memory: torch.Tensor,
        last_t: torch.Tensor,
        u_b: torch.Tensor,
        v_b: torch.Tensor,
        t_b: torch.Tensor,
        ef_b: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Project, update and write back one batch; return the new (h_u, h_v)."""
        module = self._module

        dt_u = (t_b - last_t[u_b]).clamp(min=0.0)
        dt_v = (t_b - last_t[v_b]).clamp(min=0.0)

        # Stored memory is treated as a constant: gradients reach only the
        # module parameters, never earlier batches.
        h_u_proj = module.project(memory[u_b].detach(), dt_u, module.W_proj_u)
        h_v_proj = module.project(memory[v_b].detach(), dt_v, module.W_proj_v)

        # The counterpart embedding is detached so each RNN's input carries
        # no gradient into the other role's projection.
        h_u_new = module.update(
            h_u_proj, h_v_proj.detach(), ef_b, module.rnn_u, module.norm_u
        )
        h_v_new = module.update(
            h_v_proj, h_u_proj.detach(), ef_b, module.rnn_v, module.norm_v
        )

        self._write_back(memory, last_t, u_b, v_b, t_b, h_u_new, h_v_new)
        return h_u_new, h_v_new

    # ------------------------------------------------------------------
    # Training
    # ------------------------------------------------------------------

    def fit(self, df_train: pd.DataFrame, num_epochs: int = 3) -> None:
        """Train the JODIE module and an edge classifier on a stream of events.

        Every epoch resets the memory and replays the events in timestamp
        order. After the last epoch the memory holds the end-of-training node
        states, and a new, untrained node classifier is created.

        Args:
            df_train: Events with columns ``timestamp``, ``sender_id``,
                ``receiver_id`` and ``is_fraud``, plus whatever
                ``build_edge_features`` reads.
            num_epochs: Number of passes over the training stream.

        Raises:
            ValueError: If ``df_train`` has no rows.
        """
        if len(df_train) == 0:
            raise ValueError("df_train must contain at least one event.")

        df_train = df_train.sort_values("timestamp").reset_index(drop=True)

        # NOTE(review): build_edge_features presumably returns a 2-D array
        # with one row per event - confirm against src.graph.graph_builder.
        ef_np = build_edge_features(df_train).astype(np.float32)
        edge_dim = ef_np.shape[1]
        self._edge_dim = edge_dim

        ea_mean = ef_np.mean(axis=0)
        ea_std = ef_np.std(axis=0) + 1e-6
        ef_np = (ef_np - ea_mean) / ea_std

        # Normalise time in float64 and cast afterwards: float32 cannot
        # resolve small gaps between large raw timestamps.
        t_vals = df_train["timestamp"].to_numpy(dtype=np.float64)
        t_min, t_max = t_vals.min(), t_vals.max()
        t_norm = ((t_vals - t_min) / (t_max - t_min + 1e-6)).astype(np.float32)

        self._norm_stats = {
            "ea_mean": ea_mean, "ea_std": ea_std,
            "t_min": t_min, "t_max": t_max,
        }

        # Memory rows are indexed directly by node id.
        n_nodes = int(
            max(df_train["sender_id"].max(), df_train["receiver_id"].max())
        ) + 1
        self._n_nodes = n_nodes

        module = _JODIEModule(
            memory_dim=self.memory_dim,
            edge_dim=edge_dim,
            time_emb_dim=self.time_emb_dim,
        ).to(self.device)
        self._module = module

        memory = torch.zeros(n_nodes, self.memory_dim, device=self.device)
        last_t = torch.zeros(n_nodes, device=self.device)
        self._memory = memory
        self._last_t = last_t

        u_ids = torch.tensor(df_train["sender_id"].values, dtype=torch.long)
        v_ids = torch.tensor(df_train["receiver_id"].values, dtype=torch.long)
        ef_all = torch.tensor(ef_np, dtype=torch.float32)
        t_all = torch.tensor(t_norm, dtype=torch.float32)
        y_all = torch.tensor(df_train["is_fraud"].values, dtype=torch.float32)

        # Negative/positive ratio as the positive-class weight, capped at 10.
        raw_pw = (y_all == 0).sum() / ((y_all == 1).sum() + 1e-6)
        pos_weight = torch.clamp(raw_pw, max=10.0).to(self.device)
        loss_fn = nn.BCEWithLogitsLoss(pos_weight=pos_weight)

        # Edge-level head: supplies the training signal from per-event labels.
        edge_clf = nn.Sequential(
            nn.Linear(self.memory_dim * 2 + edge_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 1),
        ).to(self.device)
        self._edge_clf = edge_clf

        params = list(module.parameters()) + list(edge_clf.parameters())
        opt = torch.optim.Adam(params, lr=1e-3)

        batch_size = self._BATCH_SIZE
        N = len(df_train)

        module.train()
        edge_clf.train()
        for epoch in range(num_epochs):
            memory.zero_()
            last_t.zero_()
            total_loss = 0.0

            for i in range(0, N, batch_size):
                j = min(i + batch_size, N)
                u_b = u_ids[i:j].to(self.device)
                v_b = v_ids[i:j].to(self.device)
                t_b = t_all[i:j].to(self.device)
                ef_b = ef_all[i:j].to(self.device)
                y_b = y_all[i:j].to(self.device)

                h_u_new, h_v_new = self._advance(memory, last_t, u_b, v_b, t_b, ef_b)

                logits = edge_clf(torch.cat([h_u_new, h_v_new, ef_b], dim=-1)).squeeze(-1)
                logits = torch.clamp(logits, -10, 10)
                loss = loss_fn(logits, y_b)

                # A NaN batch is skipped instead of poisoning the weights.
                if not torch.isnan(loss):
                    opt.zero_grad()
                    loss.backward()
                    # Clip everything the optimiser steps, not just the module.
                    torch.nn.utils.clip_grad_norm_(params, 1.0)
                    opt.step()
                    total_loss += loss.item()

            print(f"[JODIE] Epoch {epoch + 1}/{num_epochs} Loss: {total_loss:.4f}")

        # New node-level head; it stays untrained until train_node_classifier().
        self._node_clf = self._make_node_clf()

    # ------------------------------------------------------------------
    # Inference
    # ------------------------------------------------------------------

    def predict(self, df_eval: pd.DataFrame, eval_nodes: List[int]) -> np.ndarray:
        """Advance the memory over ``df_eval`` and score ``eval_nodes``.

        The persistent memory is updated in place, so repeated calls keep
        accumulating state. Node ids in ``df_eval`` that were not seen during
        ``fit`` get zero-initialised memory rows. Ids in ``eval_nodes``
        outside the table are scored from a zero embedding.

        Args:
            df_eval: Evaluation events, same schema as the training frame
                (labels are not read).
            eval_nodes: Node ids to score.

        Returns:
            float32 array of sigmoid probabilities, one per entry of
            ``eval_nodes``.
        """
        assert self._module is not None, "Call fit() first."
        df_eval = df_eval.sort_values("timestamp").reset_index(drop=True)
        device = self.device
        module = self._module
        n_events = len(df_eval)

        if n_events > 0:
            max_seen_id = int(
                max(df_eval["sender_id"].max(), df_eval["receiver_id"].max())
            ) + 1
            self._ensure_capacity(max_seen_id)

            ef_np, t_norm = self._prepare_eval_arrays(df_eval)
            u_ids = torch.tensor(df_eval["sender_id"].values, dtype=torch.long)
            v_ids = torch.tensor(df_eval["receiver_id"].values, dtype=torch.long)
            ef_t = torch.tensor(ef_np, dtype=torch.float32)
            t_t = torch.tensor(t_norm, dtype=torch.float32)

            module.eval()
            batch_size = self._BATCH_SIZE
            with torch.no_grad():
                for i in range(0, n_events, batch_size):
                    j = min(i + batch_size, n_events)
                    self._advance(
                        self._memory,
                        self._last_t,
                        u_ids[i:j].to(device),
                        v_ids[i:j].to(device),
                        t_t[i:j].to(device),
                        ef_t[i:j].to(device),
                    )

        node_emb = self._lookup_embeddings(eval_nodes)

        if self._node_clf is None:
            self._node_clf = self._make_node_clf()
        with torch.no_grad():
            probs = torch.sigmoid(self._node_clf(node_emb).squeeze(-1)).cpu().numpy()
        return probs.astype(np.float32)

    def extract_prefix_embeddings(
        self,
        df_eval: pd.DataFrame,
        examples: pd.DataFrame,
    ) -> np.ndarray:
        """Replay ``df_eval`` from an empty memory and capture sender embeddings.

        Events are processed one at a time. After each event the sender's
        embedding (which already includes that event) is copied into every
        example whose ``(sender_id, eval_local_event_idx)`` matches the
        event's ``(sender_id, local_event_idx)``. The persistent memory used
        by ``predict`` is not touched.

        Args:
            df_eval: Evaluation events. If it has no ``local_event_idx``
                column, one is derived as the per-sender running count in
                timestamp order.
            examples: Frame with ``sender_id`` and ``eval_local_event_idx``
                columns.

        Returns:
            float32 array of shape ``(len(examples), memory_dim)``. Rows with
            no matching event stay zero.
        """
        assert self._module is not None, "Call fit() first."
        if examples.empty:
            return np.zeros((0, self.memory_dim), dtype=np.float32)

        out = np.zeros((len(examples), self.memory_dim), dtype=np.float32)
        if len(df_eval) == 0:
            # Nothing to replay, so every example keeps its zero embedding.
            return out

        df_eval = df_eval.sort_values("timestamp").reset_index(drop=True).copy()
        if "local_event_idx" not in df_eval.columns:
            df_eval["local_event_idx"] = (
                df_eval.groupby("sender_id").cumcount().astype(np.int32)
            )

        # (sender id, local event index) -> output rows that want that state.
        capture_map: dict[tuple[int, int], list[int]] = {}
        for ex_idx, row in enumerate(examples.itertuples(index=False)):
            key = (int(row.sender_id), int(row.eval_local_event_idx))
            capture_map.setdefault(key, []).append(ex_idx)

        max_seen_id = int(
            max(df_eval["sender_id"].max(), df_eval["receiver_id"].max())
        ) + 1
        n_slots = max(self._n_nodes, max_seen_id)
        memory = torch.zeros(n_slots, self.memory_dim, device=self.device)
        last_t = torch.zeros(n_slots, device=self.device)

        ef_np, t_norm = self._prepare_eval_arrays(df_eval)
        sender_np = df_eval["sender_id"].to_numpy()
        local_np = df_eval["local_event_idx"].to_numpy()

        # Convert once up front instead of building four tensors per event.
        u_all = torch.tensor(sender_np, dtype=torch.long, device=self.device)
        v_all = torch.tensor(
            df_eval["receiver_id"].to_numpy(), dtype=torch.long, device=self.device
        )
        ef_all = torch.tensor(ef_np, dtype=torch.float32, device=self.device)
        t_all = torch.tensor(t_norm, dtype=torch.float32, device=self.device)

        self._module.eval()
        with torch.no_grad():
            for idx, (sender, local_idx) in enumerate(zip(sender_np, local_np)):
                self._advance(
                    memory,
                    last_t,
                    u_all[idx:idx + 1],
                    v_all[idx:idx + 1],
                    t_all[idx:idx + 1],
                    ef_all[idx:idx + 1],
                )

                wanted = capture_map.get((int(sender), int(local_idx)))
                if wanted:
                    emb = memory[int(sender)].detach().cpu().numpy().astype(np.float32)
                    for ex_idx in wanted:
                        out[ex_idx] = emb

        return out

    # ------------------------------------------------------------------
    # State management
    # ------------------------------------------------------------------

    def reset_memory(self) -> None:
        """Zero the persistent memory and last-seen times; no-op before fit()."""
        if self._memory is not None:
            self._memory.zero_()
            self._last_t.zero_()

    # ------------------------------------------------------------------
    # Node classifier
    # ------------------------------------------------------------------

    def train_node_classifier(
        self, eval_nodes: List[int], y_labels: np.ndarray, num_epochs: int = 150
    ) -> None:
        """Fit the node-level head on the current memory of ``eval_nodes``.

        Only the head is optimised. Embeddings are read from the persistent
        memory and detached. Ids outside the memory table use a zero
        embedding, matching ``predict``.

        Args:
            eval_nodes: Node ids whose embeddings form the training inputs.
            y_labels: Binary labels aligned with ``eval_nodes``.
            num_epochs: Number of full-batch Adam steps.
        """
        assert self._memory is not None, "Call fit() first."
        device = self.device
        if self._node_clf is None:
            self._node_clf = self._make_node_clf()

        node_emb = self._lookup_embeddings(eval_nodes).detach()
        y = torch.tensor(y_labels, dtype=torch.float32, device=device)

        # Negative/positive ratio as the positive-class weight, capped at 10.
        pw = torch.clamp((y == 0).sum() / ((y == 1).sum() + 1e-6), max=10.0)
        loss_fn = nn.BCEWithLogitsLoss(pos_weight=pw)
        opt = torch.optim.Adam(self._node_clf.parameters(), lr=1e-3)

        self._node_clf.train()
        for _ in range(num_epochs):
            logits = self._node_clf(node_emb).squeeze(-1)
            loss = loss_fn(logits, y)
            opt.zero_grad()
            loss.backward()
            opt.step()
        self._node_clf.eval()
|
|