# temporal-twins-anon's picture
# Add anonymous Temporal Twins code release
# a3682cf verified
"""
models/jodie.py
===============
JODIE: Predicting Dynamic Embedding Trajectory in Temporal Interaction Networks
Kumar et al., KDD 2019
Architecture
------------
JODIE maintains dual dynamic embeddings — one per node role:
- User (sender) embedding: h_u ← updated on each outgoing event
- Item (receiver) embedding: h_v ← updated on each incoming event
Key ideas:
1. Time projection: Before each update, project the existing embedding forward
in time using a learned linear transformation conditioned on Δt:
ĥ_u(t) = (1 + W_u · Δt_emb) ⊙ h_u [element-wise time scaling]
where Δt_emb = ReLU(linear(Δt)) is a learnable time embedding; Δt is clamped to [0, 5] and the scale factor to [-2, 2] for numerical stability.
2. RNN update: After projection, the RNN ingests the *other node's projected
embedding* concatenated with edge features:
h_u ← RNN( cat(ĥ_v, edge_feat), ĥ_u )
h_v ← RNN( cat(ĥ_u, edge_feat), ĥ_v )
3. Node classifier: a small MLP applied to each evaluated node's latest stored (post-update) embedding at evaluation time.
This is a faithful re-implementation of the JODIE equations from the KDD'19 paper,
adapted to the event-stream training loop of the upi-sim benchmark.
"""
from __future__ import annotations
from typing import List
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from models.base import TemporalModel
from src.graph.graph_builder import build_edge_features
# ------------------------------------------------------------------ #
# Core JODIE nn.Module #
# ------------------------------------------------------------------ #
class _JODIEModule(nn.Module):
    """Trainable components of JODIE.

    Holds a scalar-to-vector time embedding, one time-projection matrix,
    one GRU cell and one LayerNorm per role (``_u`` / ``_v``), and a small
    MLP head that maps an embedding to a single logit.
    """

    def __init__(self, memory_dim: int, edge_dim: int, time_emb_dim: int = 16):
        super().__init__()
        self.memory_dim = memory_dim
        gru_in_dim = memory_dim + edge_dim

        # NOTE: sub-modules are built in a fixed order so that seeded
        # parameter initialisation stays reproducible.

        # Scalar Δt -> time feature vector.
        self.time_emb = nn.Linear(1, time_emb_dim)

        # Time feature -> per-dimension scale offset (no bias).
        self.W_proj_u = nn.Linear(time_emb_dim, memory_dim, bias=False)
        self.W_proj_v = nn.Linear(time_emb_dim, memory_dim, bias=False)

        # GRU input is [other endpoint embedding ‖ edge features].
        self.rnn_u = nn.GRUCell(gru_in_dim, memory_dim)
        self.rnn_v = nn.GRUCell(gru_in_dim, memory_dim)

        # Normalise GRU outputs to keep embedding magnitudes bounded.
        self.norm_u = nn.LayerNorm(memory_dim)
        self.norm_v = nn.LayerNorm(memory_dim)

        # Embedding -> single logit.
        hidden_layer = nn.Linear(memory_dim, 64)
        output_layer = nn.Linear(64, 1)
        self.classifier = nn.Sequential(hidden_layer, nn.ReLU(), output_layer)

    def project(
        self,
        h: torch.Tensor,   # (B, mem_dim)
        dt: torch.Tensor,  # (B,)
        W_proj: nn.Linear,
    ) -> torch.Tensor:
        """Scale ``h`` element-wise by a factor computed from ``dt``.

        ``dt`` is limited to [0, 5] before it is embedded, and the resulting
        factor ``1 + W_proj(relu(time_emb(dt)))`` is limited to [-2, 2].
        """
        bounded_dt = torch.clamp(dt, min=0.0, max=5.0)
        time_feat = torch.relu(self.time_emb(bounded_dt[..., None]))  # (B, time_emb_dim)
        gain = torch.clamp(W_proj(time_feat) + 1.0, min=-2.0, max=2.0)  # (B, mem_dim)
        return gain * h

    def update(
        self,
        h_self: torch.Tensor,     # (B, mem_dim) GRU hidden state
        h_other: torch.Tensor,    # (B, mem_dim) counterpart embedding
        edge_feat: torch.Tensor,  # (B, edge_dim)
        rnn: nn.GRUCell,
        norm: nn.LayerNorm,
    ) -> torch.Tensor:
        """Run one GRU step on ``[h_other ‖ edge_feat]`` and normalise the result."""
        gru_input = torch.cat((h_other, edge_feat), dim=-1)
        return norm(rnn(gru_input, h_self))

    def classify(self, h: torch.Tensor) -> torch.Tensor:
        """Return one logit per row of ``h`` (trailing singleton axis removed)."""
        logits = self.classifier(h)
        return logits.squeeze(-1)
# ------------------------------------------------------------------ #
# JODIEWrapper (TemporalModel interface) #
# ------------------------------------------------------------------ #
class JODIEWrapper(TemporalModel):
    """JODIE dual-RNN temporal model with time-projection embeddings.

    The wrapper owns a per-node memory matrix, a per-node "time of last
    event" vector, the normalisation statistics computed in :meth:`fit`,
    an edge-level head used as the training signal, and a node-level head
    used by :meth:`predict`.

    Events are processed in timestamp order in fixed-size batches. All
    events of a batch read the memory as it was before the batch; a node
    touched several times inside one batch receives the mean of its
    updated embeddings.
    """

    _BATCH_SIZE = 512

    def __init__(
        self,
        memory_dim: int = 64,
        time_emb_dim: int = 16,
        device: str = "cpu",
    ):
        self.memory_dim = memory_dim
        self.time_emb_dim = time_emb_dim
        self.device = torch.device(device)
        self._module: _JODIEModule | None = None
        self._memory: torch.Tensor | None = None   # (n_nodes, mem_dim)
        self._last_t: torch.Tensor | None = None   # (n_nodes,)
        self._norm_stats: dict | None = None
        self._n_nodes: int = 0
        self._edge_dim: int = 0
        # Heads are created in fit(); declared here so attribute access is
        # always safe.
        self._edge_clf: nn.Module | None = None
        self._node_clf: nn.Module | None = None

    @property
    def name(self) -> str:
        """Short model identifier."""
        return "JODIE"

    # ------------------------------------------------------------------ #
    # Internal helpers                                                     #
    # ------------------------------------------------------------------ #
    def _new_node_clf(self) -> nn.Sequential:
        """Build an untrained node-level head on ``self.device``."""
        return nn.Sequential(
            nn.Linear(self.memory_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 1),
        ).to(self.device)

    def _scaled_edge_features(self, df: pd.DataFrame) -> np.ndarray:
        """Edge features of ``df`` standardised with the statistics from fit()."""
        ns = self._norm_stats
        ef_np = build_edge_features(df).astype(np.float32)
        return (ef_np - ns["ea_mean"]) / ns["ea_std"]

    def _scaled_time(self, df: pd.DataFrame) -> np.ndarray:
        """Timestamps of ``df`` rescaled with the training min/max.

        The arithmetic is done in float64 and only the result is cast to
        float32, so large raw timestamps keep their resolution.
        """
        ns = self._norm_stats
        t_vals = df["timestamp"].values.astype(np.float64)
        t_norm = (t_vals - ns["t_min"]) / (ns["t_max"] - ns["t_min"] + 1e-6)
        return t_norm.astype(np.float32)

    def _ensure_capacity(self, n_nodes: int) -> None:
        """Grow the memory buffers with zero rows so ids ``< n_nodes`` are valid."""
        current = self._memory.shape[0]
        if n_nodes <= current:
            return
        extra = n_nodes - current
        self._memory = torch.cat(
            [self._memory, self._memory.new_zeros(extra, self.memory_dim)], dim=0
        )
        self._last_t = torch.cat([self._last_t, self._last_t.new_zeros(extra)], dim=0)
        self._n_nodes = n_nodes

    def _step(
        self,
        memory: torch.Tensor,   # (n_nodes, mem_dim), updated in place
        last_t: torch.Tensor,   # (n_nodes,), updated in place
        u_b: torch.Tensor,      # (B,) sender ids
        v_b: torch.Tensor,      # (B,) receiver ids
        t_b: torch.Tensor,      # (B,) normalised event times
        ef_b: torch.Tensor,     # (B, edge_dim)
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Process one batch of events and write the results back.

        Returns the updated sender and receiver embeddings. They stay
        attached to the autograd graph when gradients are enabled; the
        values written to ``memory`` are detached.
        """
        module = self._module

        dt_u = (t_b - last_t[u_b]).clamp(min=0.0)
        dt_v = (t_b - last_t[v_b]).clamp(min=0.0)

        # Indexing with an id tensor returns a copy, so the in-place memory
        # write below cannot alias these inputs.
        h_u_proj = module.project(memory[u_b], dt_u, module.W_proj_u)
        h_v_proj = module.project(memory[v_b], dt_v, module.W_proj_v)

        # The counterpart embedding enters each GRU as a constant.
        h_u_new = module.update(h_u_proj, h_v_proj.detach(), ef_b, module.rnn_u, module.norm_u)
        h_v_new = module.update(h_v_proj, h_u_proj.detach(), ef_b, module.rnn_v, module.norm_v)

        # Mean-aggregate all updates that hit the same node in this batch.
        touched = torch.cat([u_b, v_b])
        fresh = torch.nan_to_num(torch.cat([h_u_new, h_v_new], dim=0).detach(), nan=0.0)
        unique_ids, inv = torch.unique(touched, return_inverse=True)
        summed = torch.zeros(
            len(unique_ids), self.memory_dim, device=memory.device, dtype=memory.dtype
        )
        summed.index_add_(0, inv, fresh)
        counts = torch.bincount(inv).unsqueeze(1).to(summed.dtype)
        memory[unique_ids] = summed / counts

        # Record the latest time at which each node was touched in this
        # batch. A max-reduction is used because plain indexed assignment
        # with repeated ids has no defined winner, and writing receivers
        # after senders could replace a later time with an earlier one.
        last_t.scatter_reduce_(
            0, touched, torch.cat([t_b, t_b]), reduce="amax", include_self=False
        )
        return h_u_new, h_v_new

    # ------------------------------------------------------------------ #
    def fit(self, df_train: pd.DataFrame, num_epochs: int = 3) -> None:
        """Train the JODIE module and an edge-level head on ``df_train``.

        Reads the columns ``timestamp``, ``sender_id``, ``receiver_id`` and
        ``is_fraud``. Memory is reset at the start of every epoch, so after
        the call it holds the state reached at the end of the last pass.
        A fresh, untrained node-level head is created at the end.

        Raises:
            ValueError: if ``df_train`` has no rows.
        """
        if df_train.empty:
            raise ValueError("JODIE.fit() requires at least one training event.")
        df_train = df_train.sort_values("timestamp").reset_index(drop=True)

        # Standardise edge features column-wise.
        ef_np = build_edge_features(df_train).astype(np.float32)
        edge_dim = ef_np.shape[1]
        self._edge_dim = edge_dim
        ea_mean = ef_np.mean(axis=0)
        ea_std = ef_np.std(axis=0) + 1e-6
        ef_np = (ef_np - ea_mean) / ea_std

        # Min-max scale time in float64; cast only the normalised result.
        t_vals = df_train["timestamp"].values.astype(np.float64)
        t_min, t_max = t_vals.min(), t_vals.max()
        t_norm = ((t_vals - t_min) / (t_max - t_min + 1e-6)).astype(np.float32)

        self._norm_stats = {
            "ea_mean": ea_mean, "ea_std": ea_std,
            "t_min": t_min, "t_max": t_max,
        }

        # Node ids are used directly as row indices into the memory table.
        n_nodes = int(max(df_train["sender_id"].max(), df_train["receiver_id"].max())) + 1
        self._n_nodes = n_nodes

        module = _JODIEModule(
            memory_dim=self.memory_dim,
            edge_dim=edge_dim,
            time_emb_dim=self.time_emb_dim,
        ).to(self.device)
        self._module = module

        memory = torch.zeros(n_nodes, self.memory_dim, device=self.device)
        last_t = torch.zeros(n_nodes, device=self.device)
        self._memory = memory
        self._last_t = last_t

        u_ids = torch.tensor(df_train["sender_id"].values, dtype=torch.long)
        v_ids = torch.tensor(df_train["receiver_id"].values, dtype=torch.long)
        ef_all = torch.tensor(ef_np, dtype=torch.float32)
        t_all = torch.tensor(t_norm, dtype=torch.float32)
        y_all = torch.tensor(df_train["is_fraud"].values, dtype=torch.float32)

        # Negative/positive ratio, capped at 10.
        raw_pw = (y_all == 0).sum() / ((y_all == 1).sum() + 1e-6)
        pos_weight = torch.clamp(raw_pw, max=10.0).to(self.device)
        loss_fn = nn.BCEWithLogitsLoss(pos_weight=pos_weight)

        # Edge-level classifier for proxy supervision during training
        edge_clf = nn.Sequential(
            nn.Linear(self.memory_dim * 2 + edge_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 1),
        ).to(self.device)
        self._edge_clf = edge_clf

        opt = torch.optim.Adam(
            list(module.parameters()) + list(edge_clf.parameters()),
            lr=1e-3,
        )

        batch_size = self._BATCH_SIZE
        N = len(df_train)
        for epoch in range(num_epochs):
            memory.zero_()
            last_t.zero_()
            total_loss = 0.0

            for i in range(0, N, batch_size):
                j = min(i + batch_size, N)
                u_b = u_ids[i:j].to(self.device)
                v_b = v_ids[i:j].to(self.device)
                t_b = t_all[i:j].to(self.device)
                ef_b = ef_all[i:j].to(self.device)
                y_b = y_all[i:j].to(self.device)

                h_u_new, h_v_new = self._step(memory, last_t, u_b, v_b, t_b, ef_b)

                # Loss: edge-level fraud classification
                ef_concat = torch.cat([h_u_new, h_v_new, ef_b], dim=-1)
                logits = edge_clf(ef_concat).squeeze(-1)
                logits = torch.clamp(logits, -10, 10)
                loss = loss_fn(logits, y_b)

                # Skip the optimiser step for a NaN batch loss.
                if not torch.isnan(loss):
                    opt.zero_grad()
                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(module.parameters(), 1.0)
                    opt.step()
                    total_loss += loss.item()

            print(f"[JODIE] Epoch {epoch + 1}/{num_epochs} Loss: {total_loss:.4f}")

        # Node classifier on sender memory
        self._node_clf = self._new_node_clf()

    # ------------------------------------------------------------------ #
    def predict(self, df_eval: pd.DataFrame, eval_nodes: List[int]) -> np.ndarray:
        """Replay ``df_eval`` through the memory and score ``eval_nodes``.

        The persistent memory is updated in place (no gradients). Node ids
        beyond the current table are given fresh zero rows instead of being
        mapped onto an existing node. If no node-level head exists yet, an
        untrained one is created.

        Returns:
            float32 array with one sigmoid score per entry of ``eval_nodes``.
        """
        assert self._module is not None, "Call fit() first."
        df_eval = df_eval.sort_values("timestamp").reset_index(drop=True)
        device = self.device
        eval_ids = [int(n) for n in eval_nodes]

        # Make every id that will be indexed below addressable.
        highest_id = -1
        if len(df_eval) > 0:
            highest_id = int(max(df_eval["sender_id"].max(), df_eval["receiver_id"].max()))
        if eval_ids:
            highest_id = max(highest_id, max(eval_ids))
        self._ensure_capacity(highest_id + 1)
        memory, last_t = self._memory, self._last_t

        ef_t = torch.tensor(self._scaled_edge_features(df_eval), dtype=torch.float32)
        t_t = torch.tensor(self._scaled_time(df_eval), dtype=torch.float32)
        u_ids = torch.tensor(df_eval["sender_id"].values, dtype=torch.long)
        v_ids = torch.tensor(df_eval["receiver_id"].values, dtype=torch.long)

        # Guard: init classifier if fit() did not leave one behind.
        if getattr(self, "_node_clf", None) is None:
            self._node_clf = self._new_node_clf()

        self._module.eval()
        batch_size = self._BATCH_SIZE
        n_events = len(df_eval)
        with torch.no_grad():
            for i in range(0, n_events, batch_size):
                j = min(i + batch_size, n_events)
                self._step(
                    memory,
                    last_t,
                    u_ids[i:j].to(device),
                    v_ids[i:j].to(device),
                    t_t[i:j].to(device),
                    ef_t[i:j].to(device),
                )

            eval_t = torch.tensor(eval_ids, dtype=torch.long, device=device)
            node_emb = memory[eval_t]
            probs = torch.sigmoid(self._node_clf(node_emb).squeeze(-1)).cpu().numpy()
        return probs.astype(np.float32)

    def extract_prefix_embeddings(
        self,
        df_eval: pd.DataFrame,
        examples: pd.DataFrame,
    ) -> np.ndarray:
        """Capture sender embeddings at selected points of an event stream.

        ``df_eval`` is replayed one event at a time on a scratch memory that
        starts at zero; the persistent memory is not touched. Whenever an
        event's ``(sender_id, local_event_idx)`` equals an example's
        ``(sender_id, eval_local_event_idx)``, the sender's embedding right
        after that event is stored in the example's output row. If
        ``local_event_idx`` is missing it is derived as the per-sender event
        counter in timestamp order.

        Returns:
            float32 array of shape ``(len(examples), memory_dim)``; rows of
            examples that never match stay zero.
        """
        assert self._module is not None, "Call fit() first."
        if examples.empty:
            return np.zeros((0, self.memory_dim), dtype=np.float32)

        out = np.zeros((len(examples), self.memory_dim), dtype=np.float32)
        if df_eval.empty:
            # Nothing to replay, so no example can match.
            return out

        df_eval = df_eval.sort_values("timestamp").reset_index(drop=True).copy()
        if "local_event_idx" not in df_eval.columns:
            df_eval["local_event_idx"] = df_eval.groupby("sender_id").cumcount().astype(np.int32)

        # (sender, local index) -> output rows that want this snapshot.
        capture_map: dict[tuple[int, int], list[int]] = {}
        for ex_idx, row in enumerate(examples.itertuples(index=False)):
            key = (int(row.sender_id), int(row.eval_local_event_idx))
            capture_map.setdefault(key, []).append(ex_idx)

        max_seen_id = int(max(df_eval["sender_id"].max(), df_eval["receiver_id"].max())) + 1
        n_slots = max(self._n_nodes, max_seen_id)
        memory = torch.zeros(n_slots, self.memory_dim, device=self.device)
        last_t = torch.zeros(n_slots, device=self.device)

        device = self.device
        ef_all = torch.tensor(self._scaled_edge_features(df_eval), dtype=torch.float32, device=device)
        t_all = torch.tensor(self._scaled_time(df_eval), dtype=torch.float32, device=device)
        u_all = torch.tensor(df_eval["sender_id"].values, dtype=torch.long, device=device)
        v_all = torch.tensor(df_eval["receiver_id"].values, dtype=torch.long, device=device)

        senders = df_eval["sender_id"].values
        local_idx = df_eval["local_event_idx"].values

        self._module.eval()
        with torch.no_grad():
            for idx, (sender, event_no) in enumerate(zip(senders, local_idx)):
                one = slice(idx, idx + 1)
                self._step(memory, last_t, u_all[one], v_all[one], t_all[one], ef_all[one])

                rows = capture_map.get((int(sender), int(event_no)))
                if rows:
                    emb = memory[int(sender)].detach().cpu().numpy().astype(np.float32)
                    out[rows] = emb
        return out

    # ------------------------------------------------------------------ #
    def reset_memory(self) -> None:
        """Zero the persistent memory and last-event times, if allocated."""
        if self._memory is not None:
            self._memory.zero_()
        if self._last_t is not None:
            self._last_t.zero_()

    # ------------------------------------------------------------------ #
    def train_node_classifier(
        self, eval_nodes: List[int], y_labels: np.ndarray, num_epochs: int = 150
    ) -> None:
        """Fit the node-level head on the current memory of ``eval_nodes``.

        Embeddings are read once and treated as constants; only the head is
        optimised, full-batch, with a class-weighted BCE loss. Node ids
        beyond the current table are given zero rows first, and the head is
        created if it does not exist yet.
        """
        assert self._memory is not None, "Call fit() first."
        device = self.device
        eval_ids = [int(n) for n in eval_nodes]
        if eval_ids:
            self._ensure_capacity(max(eval_ids) + 1)
        if getattr(self, "_node_clf", None) is None:
            self._node_clf = self._new_node_clf()

        eval_t = torch.tensor(eval_ids, dtype=torch.long, device=device)
        node_emb = self._memory[eval_t].detach()
        y = torch.tensor(y_labels, dtype=torch.float32, device=device)

        # Negative/positive ratio, capped at 10.
        pw = torch.clamp((y == 0).sum() / ((y == 1).sum() + 1e-6), max=10.0)
        loss_fn = nn.BCEWithLogitsLoss(pos_weight=pw)
        opt = torch.optim.Adam(self._node_clf.parameters(), lr=1e-3)

        self._node_clf.train()
        for _ in range(num_epochs):
            logits = self._node_clf(node_emb).squeeze(-1)
            loss = loss_fn(logits, y)
            opt.zero_grad()
            loss.backward()
            opt.step()
        self._node_clf.eval()