temporal-twins-anon's picture
Add anonymous Temporal Twins code release
a3682cf verified
"""
models/tgat.py
==============
Temporal Graph Attention Network (TGAT)
Xu et al., "Inductive Representation Learning on Temporal Graphs" (ICLR 2020)
Architecture
------------
- Sinusoidal time encoding (reuses src/tgn/time_encoding.py)
- Per-node ring buffer of K most recent temporal neighbors
- Multi-head scaled dot-product attention over temporal neighborhood
- GRU-cell aggregator updates node memory after each event
- Node classifier head: memory → fraud probability
Event processing (streaming, chronological):
For each edge (u, v, t, edge_feat):
1. Retrieve last K neighbors of u from buffer → {(t_i, h_i, e_i)}
2. Build query: Q = W_q(cat(h_u, φ(0))) [current state at t]
Build keys: K = W_k(cat(h_i, φ(t−t_i))) [neighbor state at t_i]
Build vals: V = W_v(cat(h_i, e_i, φ(t−t_i))) [neighbor context]
3. attn = softmax(Q K^T / √d), z = attn·V
4. h_u ← GRU(merge(z ‖ h_u), h_u) [update sender memory]
5. Symmetrically update h_v by attending over v's own neighborhood
6. Append (t, h_v, e) to u's neighbor buffer and (t, h_u, e) to v's
"""
from __future__ import annotations
from collections import defaultdict
from typing import List
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from models.base import TemporalModel
from models.tgn_wrapper import _make_users_df
from src.graph.graph_builder import build_edge_features
from src.tgn.time_encoding import TimeEncoding
# ------------------------------------------------------------------ #
# Core TGAT nn.Module #
# ------------------------------------------------------------------ #
class _TGATModule(nn.Module):
    """Single-layer TGAT attention block, memory updater and node head.

    Args:
        memory_dim: width of each node's memory vector; the attention output
            uses the same width.
        edge_dim: width of the edge-feature vectors fed to the value projection.
        time_dim: forwarded to ``TimeEncoding``. The projection widths below
            assume the encoder emits ``2 * time_dim`` features per timestamp
            (NOTE(review): confirm against src/tgn/time_encoding.py).
        num_heads: number of attention heads; must divide ``memory_dim``.
    """

    def __init__(
        self,
        memory_dim: int,
        edge_dim: int,
        time_dim: int,
        num_heads: int,
    ):
        super().__init__()
        self.memory_dim = memory_dim
        self.time_enc = TimeEncoding(time_dim)

        time_width = 2 * time_dim
        query_width = memory_dim + time_width             # h_u   || phi(0)
        key_width = memory_dim + time_width               # h_nbr || phi(dt)
        value_width = memory_dim + edge_dim + time_width  # h_nbr || e || phi(dt)

        self.attn_dim = memory_dim
        self.num_heads = num_heads
        assert self.attn_dim % num_heads == 0, "attn_dim must be divisible by num_heads"

        # Layers are created in a fixed order so parameter initialisation
        # consumes the RNG identically across versions.
        self.W_q = nn.Linear(query_width, self.attn_dim, bias=False)
        self.W_k = nn.Linear(key_width, self.attn_dim, bias=False)
        self.W_v = nn.Linear(value_width, self.attn_dim, bias=False)
        head_width = self.attn_dim // num_heads
        self.scale = head_width ** -0.5

        # Combines the attention context with the current memory before the GRU.
        self.merge = nn.Linear(self.attn_dim + memory_dim, memory_dim)
        self.gru = nn.GRUCell(memory_dim, memory_dim)

        # Node-level scoring head: memory -> single logit.
        self.classifier = nn.Sequential(
            nn.Linear(memory_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 1),
        )

    def attend(
        self,
        h_u: torch.Tensor,      # (B, memory_dim) — current node state
        h_nbrs: torch.Tensor,   # (B, K, memory_dim)
        e_nbrs: torch.Tensor,   # (B, K, edge_dim)
        dt_nbrs: torch.Tensor,  # (B, K) — time deltas
        mask: torch.Tensor,     # (B, K) bool — True = valid
    ) -> torch.Tensor:
        """Multi-head scaled dot-product attention over a temporal neighbourhood.

        The query is built from ``h_u`` and phi(0). Keys are built from each
        neighbour's state and phi(dt), and values additionally include the
        edge features. Positions where ``mask`` is False are excluded. If
        ``mask`` is None, nothing is excluded. A row with no valid neighbour
        yields an all-zero context.

        Returns:
            Tensor of shape (B, attn_dim) holding the concatenated head outputs.
        """
        batch, n_nbr = dt_nbrs.shape
        heads = self.num_heads
        head_width = self.attn_dim // heads

        # Time features: phi(0) for the query, phi(dt) for every neighbour slot.
        phi_now = self.time_enc(torch.zeros(batch, device=h_u.device))
        phi_gap = self.time_enc(dt_nbrs.reshape(-1)).reshape(batch, n_nbr, -1)

        nbr_flat = h_nbrs.reshape(batch * n_nbr, -1)
        phi_gap_flat = phi_gap.reshape(batch * n_nbr, -1)

        def split_heads(projected: torch.Tensor) -> torch.Tensor:
            # (B*K, attn_dim) -> (B, H, K, head_width)
            return projected.view(batch, n_nbr, heads, head_width).permute(0, 2, 1, 3)

        query = self.W_q(torch.cat([h_u, phi_now], dim=-1)).view(batch, heads, head_width)
        keys = split_heads(self.W_k(torch.cat([nbr_flat, phi_gap_flat], dim=-1)))
        values = split_heads(
            self.W_v(
                torch.cat(
                    [nbr_flat, e_nbrs.reshape(batch * n_nbr, -1), phi_gap_flat],
                    dim=-1,
                )
            )
        )

        # (B, H, 1, d) @ (B, H, d, K) -> (B, H, K)
        logits = (query.unsqueeze(2) @ keys.transpose(-2, -1)).squeeze(2)
        logits = logits * self.scale

        if mask is not None:
            # Padding slots get -inf so softmax assigns them zero weight.
            logits = logits.masked_fill(~mask.unsqueeze(1), float("-inf"))

        # A fully masked row makes softmax return NaN; map that to 0 weights.
        weights = torch.nan_to_num(F.softmax(logits, dim=-1), nan=0.0)

        context = (weights.unsqueeze(-1) * values).sum(dim=2)  # (B, H, head_width)
        return context.reshape(batch, self.attn_dim)

    def update(self, h_u: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
        """Fold the attention context ``z`` into ``h_u`` (merge layer, then GRU cell)."""
        gru_input = self.merge(torch.cat([z, h_u], dim=-1))
        return self.gru(gru_input, h_u)

    def classify(self, memory: torch.Tensor) -> torch.Tensor:
        """Return one logit per memory row (trailing singleton dim squeezed)."""
        return self.classifier(memory).squeeze(-1)
# ------------------------------------------------------------------ #
# TGAT Streamer (event-level memory management) #
# ------------------------------------------------------------------ #
class _TGATStreamer:
    """
    Maintains per-node memory and temporal neighbour buffers.

    Events are consumed in batches. Every event in a batch reads the memory
    and neighbour buffers as they stood *before* the batch; this approximation
    trades exact sequential semantics for throughput. If a node occurs several
    times in one batch, its last occurrence determines the stored memory. A
    receiver-side update also overrides a sender-side update of the same node
    within that batch, because receiver rows are written second.
    """

    def __init__(
        self,
        module: _TGATModule,
        n_nodes: int,
        memory_dim: int,
        edge_dim: int,
        n_neighbors: int,
        device: torch.device,
    ):
        self.module = module
        self.memory_dim = memory_dim
        self.edge_dim = edge_dim
        self.n_neighbors = n_neighbors
        self.device = device
        # Node memory, one row per node id: (n_nodes, memory_dim). Rows are only
        # ever written with detached values, so this table holds no autograd graph.
        self.memory = torch.zeros(n_nodes, memory_dim, device=device)
        # Per-node neighbour history, oldest first, at most n_neighbors entries.
        # Each entry is (event time, other endpoint's updated memory, edge
        # features); tensors are detached CPU copies.
        self.nbr_times: List[List[float]] = [[] for _ in range(n_nodes)]
        self.nbr_h: List[List[torch.Tensor]] = [[] for _ in range(n_nodes)]
        self.nbr_e: List[List[torch.Tensor]] = [[] for _ in range(n_nodes)]

    def _write_memory_rows(
        self,
        node_ids: torch.Tensor,
        values: torch.Tensor,
    ) -> None:
        """Deterministic last-write-wins update for repeated node ids in a batch.

        A plain ``memory[ids] = values`` with duplicate ids has no defined
        winner. Instead, the batch is reduced to the last position of each id,
        and a single indexed write is issued over unique rows. This also
        avoids one host sync per row.
        """
        # Later positions overwrite earlier ones, so each id maps to its last row.
        last_pos = {nid: pos for pos, nid in enumerate(node_ids.tolist())}
        if not last_pos:
            return
        rows = torch.tensor(list(last_pos.keys()), dtype=torch.long, device=self.memory.device)
        src = torch.tensor(list(last_pos.values()), dtype=torch.long, device=values.device)
        self.memory[rows] = values[src].detach().to(
            device=self.memory.device, dtype=self.memory.dtype
        )

    def _get_neighbor_tensors(
        self, node_ids: torch.Tensor
    ):
        """
        Return all-padding neighbour tensors (h_nbrs, e_nbrs, dt_nbrs, mask)
        for a batch: zero-filled features and an all-False mask.

        NOTE(review): this module has no caller of this method; the lookup
        actually used is ``_fill_neighbor_batch``. It is kept so any external
        caller keeps working.
        """
        B, K = len(node_ids), self.n_neighbors
        return (
            torch.zeros(B, K, self.memory_dim, device=self.device),
            torch.zeros(B, K, self.edge_dim, device=self.device),
            torch.zeros(B, K, device=self.device),
            torch.zeros(B, K, dtype=torch.bool, device=self.device),
        )

    def _fill_neighbor_batch(
        self,
        node_ids: torch.Tensor,
        current_times: torch.Tensor,
    ):
        """
        Build padded neighbour tensors for a batch from the per-node buffers.

        Returns:
            ``(h_nbrs, e_nbrs, dt_nbrs, mask)`` with shapes (B, K, memory_dim),
            (B, K, edge_dim), (B, K) and (B, K). Row ``b`` holds the most recent
            ``min(len(buffer), K)`` neighbours of ``node_ids[b]``, oldest first
            and left-aligned. Unused slots are zero with ``mask`` False.
            ``dt = current_time - neighbour_time``, clamped at 0.
        """
        B, K = len(node_ids), self.n_neighbors
        # Buffers live on CPU: assemble there and transfer once at the end
        # instead of issuing one small device write per neighbour slot.
        h_out = torch.zeros(B, K, self.memory_dim)
        e_out = torch.zeros(B, K, self.edge_dim)
        dt_out = torch.zeros(B, K)
        mask = torch.zeros(B, K, dtype=torch.bool)

        if K > 0:  # guard: a [-0:] slice would select the whole buffer
            rows = zip(node_ids.tolist(), current_times.tolist())
            for b_idx, (nid, t_cur) in enumerate(rows):
                recent_t = self.nbr_times[nid][-K:]
                if not recent_t:
                    continue
                n_use = len(recent_t)
                h_out[b_idx, :n_use] = torch.stack(self.nbr_h[nid][-K:])
                e_out[b_idx, :n_use] = torch.stack(self.nbr_e[nid][-K:])
                dt_out[b_idx, :n_use] = torch.tensor(
                    [max(0.0, t_cur - t_nbr) for t_nbr in recent_t]
                )
                mask[b_idx, :n_use] = True

        device = self.device
        return h_out.to(device), e_out.to(device), dt_out.to(device), mask.to(device)

    def _update_buffers(
        self,
        node_ids_np: np.ndarray,
        times_np: np.ndarray,
        h_others: torch.Tensor,    # (N, mem_dim) — embedding of the other node
        edge_feats: torch.Tensor,  # (N, edge_dim)
    ):
        """Append one (time, other-node memory, edge features) entry per event
        to the owning node's buffer, keeping at most ``n_neighbors`` entries.

        Stored tensors are detached CPU clones, so a buffer entry never keeps
        the whole batch tensor (or its storage) alive.
        """
        h_cpu = h_others.detach().cpu()
        e_cpu = edge_feats.detach().cpu()
        for i, nid in enumerate(node_ids_np):
            times_buf = self.nbr_times[nid]
            h_buf = self.nbr_h[nid]
            e_buf = self.nbr_e[nid]
            times_buf.append(float(times_np[i]))
            h_buf.append(h_cpu[i].clone())
            e_buf.append(e_cpu[i].clone())
            # Drop the oldest entry once the buffer exceeds its capacity.
            if len(times_buf) > self.n_neighbors:
                times_buf.pop(0)
                h_buf.pop(0)
                e_buf.pop(0)

    def _attend_and_update(
        self, node_ids: torch.Tensor, times: torch.Tensor
    ) -> torch.Tensor:
        """Attend over each node's neighbourhood and return its updated memory."""
        h = self.memory[node_ids].clone()  # (B, mem_dim); memory holds no graph
        h_nbrs, e_nbrs, dt, mask = self._fill_neighbor_batch(node_ids, times)
        z = self.module.attend(h, h_nbrs, e_nbrs, dt, mask)
        return self.module.update(h.detach(), z)

    def process_batch(
        self,
        u_ids: torch.Tensor,       # (B,)
        v_ids: torch.Tensor,       # (B,)
        times: torch.Tensor,       # (B,) normalised
        edge_feats: torch.Tensor,  # (B, edge_dim)
        compute_grad: bool = True,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Process one batch of events and update memory and neighbour buffers.

        Args:
            compute_grad: when False, the attention/update runs with autograd
                disabled. When True, the caller's grad mode is respected.

        Returns:
            ``(h_u_new, h_v_new)``: the post-update memory of senders and
            receivers, each (B, memory_dim). Callers attach their own heads.
        """
        with torch.set_grad_enabled(compute_grad and torch.is_grad_enabled()):
            # Both sides read pre-batch memory and buffers.
            h_u_new = self._attend_and_update(u_ids, times)
            h_v_new = self._attend_and_update(v_ids, times)

        # Receivers are written after senders, so they win on overlap.
        self._write_memory_rows(u_ids, h_u_new)
        self._write_memory_rows(v_ids, h_v_new)

        # Each endpoint records the other endpoint's updated memory.
        u_np = u_ids.cpu().numpy()
        v_np = v_ids.cpu().numpy()
        t_np = times.cpu().numpy()
        self._update_buffers(u_np, t_np, h_v_new, edge_feats)
        self._update_buffers(v_np, t_np, h_u_new, edge_feats)
        return h_u_new, h_v_new

    def reset(self):
        """Zero all memory rows and empty every neighbour buffer."""
        n_nodes = self.memory.shape[0]
        self.memory.zero_()
        self.nbr_times = [[] for _ in range(n_nodes)]
        self.nbr_h = [[] for _ in range(n_nodes)]
        self.nbr_e = [[] for _ in range(n_nodes)]
# ------------------------------------------------------------------ #
# TGATWrapper (TemporalModel interface) #
# ------------------------------------------------------------------ #
class TGATWrapper(TemporalModel):
    """TGAT wrapped behind the unified TemporalModel interface.

    Workflow:
      * ``fit`` streams the training events through a fresh TGAT module and
        trains it with an auxiliary edge-level fraud head.
      * ``train_node_classifier`` fits the node-level head on the resulting
        memory.
      * ``predict`` keeps streaming evaluation events through that memory and
        scores the requested nodes.
    """

    # Min-max normalised timestamps in [0, 1] are multiplied by this factor
    # before reaching the time encoder (shared by every entry point).
    _TIME_SCALE = 5.0
    # Number of events per streaming step in fit() and predict().
    _BATCH_SIZE = 512

    def __init__(
        self,
        memory_dim: int = 64,
        time_dim: int = 8,
        num_heads: int = 4,
        n_neighbors: int = 10,
        device: str = "cpu",
    ):
        self.memory_dim = memory_dim
        self.time_dim = time_dim
        self.num_heads = num_heads
        self.n_neighbors = n_neighbors
        self.device = torch.device(device)
        self._module: _TGATModule | None = None
        self._streamer: _TGATStreamer | None = None
        self._norm_stats: dict | None = None
        self._n_nodes: int = 0
        self._edge_dim: int = 0
        # Heads created by fit(); None until then.
        self._edge_clf: nn.Module | None = None
        self._node_clf: nn.Module | None = None

    @property
    def name(self) -> str:
        return "TGAT"

    # ------------------------------------------------------------------ #
    # Internal helpers
    # ------------------------------------------------------------------ #
    def _scaled_edge_features(self, df: pd.DataFrame) -> np.ndarray:
        """Edge features for *df*, standardised with the statistics from fit()."""
        ns = self._norm_stats
        feats = build_edge_features(df).astype(np.float32)
        return (feats - ns["ea_mean"]) / ns["ea_std"]

    def _scaled_times(self, df: pd.DataFrame) -> np.ndarray:
        """Timestamps of *df*, min-max normalised with fit() statistics, times _TIME_SCALE.

        The arithmetic uses float64. float32 has a 24-bit mantissa, so
        epoch-second timestamps (~1.7e9) would be quantised to 128-second
        steps, and nearby events would collapse onto the same time. Only the
        final, small-magnitude result is cast to float32.
        """
        ns = self._norm_stats
        t_vals = df["timestamp"].values.astype(np.float64)
        t_norm = (t_vals - ns["t_min"]) / (ns["t_max"] - ns["t_min"] + 1e-6)
        return (t_norm * self._TIME_SCALE).astype(np.float32)

    def _ensure_capacity(self, n_nodes: int) -> None:
        """Grow the trained streamer so that node ids < *n_nodes* are addressable.

        Ids first seen after training start with zero memory and an empty
        neighbour buffer. ``self._n_nodes`` is kept in sync with the streamer.
        """
        streamer = self._streamer
        current = streamer.memory.shape[0]
        if n_nodes <= current:
            return
        extra = n_nodes - current
        pad = torch.zeros(
            extra, self.memory_dim,
            device=streamer.memory.device, dtype=streamer.memory.dtype,
        )
        streamer.memory = torch.cat([streamer.memory, pad], dim=0)
        for buffers in (streamer.nbr_times, streamer.nbr_h, streamer.nbr_e):
            buffers.extend([] for _ in range(extra))
        self._n_nodes = n_nodes

    def _node_embeddings(self, node_ids: List[int]) -> torch.Tensor:
        """Current memory rows for *node_ids*, shape (len(node_ids), memory_dim).

        Ids outside the streamer's range have never been observed. They get
        a zero (cold-start) embedding rather than another node's row.
        """
        memory = self._streamer.memory
        n_rows = memory.shape[0]
        ids = [int(n) for n in node_ids]
        out = torch.zeros(len(ids), self.memory_dim, device=memory.device, dtype=memory.dtype)
        positions = [pos for pos, nid in enumerate(ids) if 0 <= nid < n_rows]
        if positions:
            dst = torch.tensor(positions, dtype=torch.long, device=memory.device)
            src = torch.tensor([ids[p] for p in positions], dtype=torch.long, device=memory.device)
            out[dst] = memory[src]
        return out.detach()

    # ------------------------------------------------------------------ #
    def fit(self, df_train: pd.DataFrame, num_epochs: int = 3) -> None:
        """Train the TGAT module by streaming *df_train* chronologically.

        The training signal is an edge-level fraud loss from an auxiliary head
        applied to (h_u, h_v, edge features). Memory is reset at the start of
        every epoch. A fresh, untrained node classifier head is created at the
        end; fit it with ``train_node_classifier``.
        """
        df_train = df_train.sort_values("timestamp").reset_index(drop=True)

        # Edge features, standardised with training statistics.
        edge_feats_np = build_edge_features(df_train)  # (N, edge_dim)
        edge_dim = edge_feats_np.shape[1]
        self._edge_dim = edge_dim
        ea_mean = edge_feats_np.mean(axis=0)
        ea_std = edge_feats_np.std(axis=0) + 1e-6
        edge_feats_np = (edge_feats_np - ea_mean) / ea_std

        # Timestamp range in float64 (see _scaled_times).
        t_vals = df_train["timestamp"].values.astype(np.float64)
        self._norm_stats = {
            "ea_mean": ea_mean, "ea_std": ea_std,
            "t_min": t_vals.min(), "t_max": t_vals.max(),
        }
        t_scaled = self._scaled_times(df_train)

        # Node ids index memory rows directly, so size the table by the max id.
        n_nodes = int(max(df_train["sender_id"].max(), df_train["receiver_id"].max())) + 1
        self._n_nodes = n_nodes

        # Build module and streamer
        module = _TGATModule(
            memory_dim=self.memory_dim,
            edge_dim=edge_dim,
            time_dim=self.time_dim,
            num_heads=self.num_heads,
        ).to(self.device)
        self._module = module
        streamer = _TGATStreamer(
            module=module,
            n_nodes=n_nodes,
            memory_dim=self.memory_dim,
            edge_dim=edge_dim,
            n_neighbors=self.n_neighbors,
            device=self.device,
        )
        self._streamer = streamer

        # Labels (edge-level) and per-event inputs.
        y = torch.tensor(df_train["is_fraud"].values, dtype=torch.float32)
        u_ids = torch.tensor(df_train["sender_id"].values, dtype=torch.long)
        v_ids = torch.tensor(df_train["receiver_id"].values, dtype=torch.long)
        ef_all = torch.tensor(edge_feats_np, dtype=torch.float32)
        t_all = torch.tensor(t_scaled, dtype=torch.float32)

        # Positive-class weight = negatives / positives, capped at 10.
        raw_pw = (y == 0).sum() / ((y == 1).sum() + 1e-6)
        pos_weight = torch.clamp(raw_pw, max=10.0).to(self.device)
        loss_fn = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
        optimiser = torch.optim.Adam(module.parameters(), lr=1e-3)

        # Edge-level head: predicts fraud for each event from both endpoints'
        # updated memory plus the edge features (proxy training signal; the
        # node classifier is fitted separately).
        edge_classifier = nn.Sequential(
            nn.Linear(self.memory_dim * 2 + edge_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 1),
        ).to(self.device)
        self._edge_clf = edge_classifier
        optimiser.add_param_group({"params": edge_classifier.parameters()})

        n_events = len(df_train)
        for epoch in range(num_epochs):
            # Re-initialise memory each epoch to avoid over-fitting to order
            streamer.reset()
            total_loss = 0.0
            for start in range(0, n_events, self._BATCH_SIZE):
                stop = min(start + self._BATCH_SIZE, n_events)
                u_b = u_ids[start:stop].to(self.device)
                v_b = v_ids[start:stop].to(self.device)
                t_b = t_all[start:stop].to(self.device)
                ef_b = ef_all[start:stop].to(self.device)
                y_b = y[start:stop].to(self.device)

                h_u, h_v = streamer.process_batch(u_b, v_b, t_b, ef_b)
                logits = edge_classifier(torch.cat([h_u, h_v, ef_b], dim=-1)).squeeze(-1)
                # Clamp to [-10, 10]; note this also zeroes gradients outside that range.
                logits = torch.clamp(logits, -10, 10)
                loss = loss_fn(logits, y_b)

                optimiser.zero_grad()
                loss.backward()
                # Only the TGAT module's gradients are clipped.
                torch.nn.utils.clip_grad_norm_(module.parameters(), 1.0)
                optimiser.step()
                total_loss += loss.item()
            print(f"[TGAT] Epoch {epoch + 1}/{num_epochs} Loss: {total_loss:.4f}")

        # Node classifier head (trained separately on node-level labels)
        self._node_clf = nn.Sequential(
            nn.Linear(self.memory_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 1),
        ).to(self.device)

    # ------------------------------------------------------------------ #
    def predict(self, df_eval: pd.DataFrame, eval_nodes: List[int]) -> np.ndarray:
        """Stream *df_eval* through the trained memory, then score *eval_nodes*.

        Memory is NOT reset first. Evaluation events continue from the state
        left by fit() or earlier predict() calls; call reset_memory() to start
        cold. Node ids in *df_eval* beyond the training range are added with
        zero memory. Eval nodes never observed get a zero embedding.

        Returns:
            float32 array of fraud probabilities aligned with *eval_nodes*.
        """
        assert self._streamer is not None, "Call fit() first."
        df_eval = df_eval.sort_values("timestamp").reset_index(drop=True)
        n_events = len(df_eval)

        self._module.eval()
        with torch.no_grad():
            if n_events:
                # Without this, ids unseen during training would index past
                # the end of the memory table inside process_batch.
                self._ensure_capacity(
                    int(max(df_eval["sender_id"].max(), df_eval["receiver_id"].max())) + 1
                )
                u_ids = torch.tensor(df_eval["sender_id"].values, dtype=torch.long)
                v_ids = torch.tensor(df_eval["receiver_id"].values, dtype=torch.long)
                ef_t = torch.tensor(self._scaled_edge_features(df_eval), dtype=torch.float32)
                t_t = torch.tensor(self._scaled_times(df_eval), dtype=torch.float32)
                for start in range(0, n_events, self._BATCH_SIZE):
                    stop = min(start + self._BATCH_SIZE, n_events)
                    self._streamer.process_batch(
                        u_ids[start:stop].to(self.device),
                        v_ids[start:stop].to(self.device),
                        t_t[start:stop].to(self.device),
                        ef_t[start:stop].to(self.device),
                        compute_grad=False,
                    )

            node_emb = self._node_embeddings(eval_nodes)
            if getattr(self, "_node_clf", None) is None:
                # Fallback: an untrained head (random initialisation).
                self._node_clf = nn.Sequential(
                    nn.Linear(self.memory_dim, 64), nn.ReLU(), nn.Linear(64, 1)
                ).to(self.device)
            logits = self._node_clf(node_emb).squeeze(-1)

        probs = torch.sigmoid(logits).cpu().numpy()
        return probs.astype(np.float32)

    def extract_prefix_embeddings(
        self,
        df_eval: pd.DataFrame,
        examples: pd.DataFrame,
    ) -> np.ndarray:
        """Capture sender memory immediately after selected events of *df_eval*.

        A fresh streamer (zero memory, empty buffers) that shares the trained
        module replays *df_eval* one event at a time. Each row of *examples*
        names a ``(sender_id, eval_local_event_idx)`` pair. Its output row is
        the sender's memory right after processing the event whose
        ``local_event_idx`` matches. If that column is absent it defaults to
        a 0-based, per-sender chronological counter. Rows whose event never
        occurs stay all-zero.

        Returns:
            float32 array of shape (len(examples), memory_dim).
        """
        assert self._module is not None, "Call fit() first."
        out = np.zeros((len(examples), self.memory_dim), dtype=np.float32)
        if examples.empty or df_eval.empty:
            return out

        df_eval = df_eval.sort_values("timestamp").reset_index(drop=True).copy()
        if "local_event_idx" not in df_eval.columns:
            df_eval["local_event_idx"] = df_eval.groupby("sender_id").cumcount().astype(np.int32)

        # (sender_id, local event index) -> rows of `examples` to fill.
        capture_map: defaultdict[tuple[int, int], list[int]] = defaultdict(list)
        for ex_idx, row in enumerate(examples.itertuples(index=False)):
            capture_map[(int(row.sender_id), int(row.eval_local_event_idx))].append(ex_idx)

        max_seen_id = int(max(df_eval["sender_id"].max(), df_eval["receiver_id"].max())) + 1
        streamer = _TGATStreamer(
            module=self._module,
            n_nodes=max(self._n_nodes, max_seen_id),
            memory_dim=self.memory_dim,
            edge_dim=self._edge_dim,
            n_neighbors=self.n_neighbors,
            device=self.device,
        )

        edge_feats_np = self._scaled_edge_features(df_eval)
        t_scaled = self._scaled_times(df_eval)

        self._module.eval()
        with torch.no_grad():
            # One event per step, so every capture sees exactly its prefix.
            for idx, row in enumerate(df_eval.itertuples(index=False)):
                sender = int(row.sender_id)
                u = torch.tensor([sender], dtype=torch.long, device=self.device)
                v = torch.tensor([int(row.receiver_id)], dtype=torch.long, device=self.device)
                t = torch.tensor([t_scaled[idx]], dtype=torch.float32, device=self.device)
                ef = torch.tensor(edge_feats_np[idx:idx + 1], dtype=torch.float32, device=self.device)
                streamer.process_batch(u, v, t, ef, compute_grad=False)

                targets = capture_map.get((sender, int(row.local_event_idx)))
                if targets:
                    out[targets] = streamer.memory[sender].detach().cpu().numpy().astype(np.float32)
        return out

    # ------------------------------------------------------------------ #
    def reset_memory(self) -> None:
        """Zero the trained streamer's memory and clear its neighbour buffers."""
        if self._streamer is not None:
            self._streamer.reset()

    # ------------------------------------------------------------------ #
    def train_node_classifier(
        self,
        eval_nodes: List[int],
        y_labels: np.ndarray,
        num_epochs: int = 150,
    ) -> None:
        """Fine-tune node classifier on node-level labels from training window.

        The streamer's current memory rows for *eval_nodes* are used as fixed
        inputs; ids never observed get a zero embedding. The head then takes
        *num_epochs* full-batch Adam steps on a BCE loss whose positive weight
        is negatives / positives, capped at 10.
        """
        assert (
            self._streamer is not None and getattr(self, "_node_clf", None) is not None
        ), "Call fit() first."
        if len(eval_nodes) == 0:
            return

        device = self.device
        node_emb = self._node_embeddings(eval_nodes)
        y = torch.tensor(y_labels, dtype=torch.float32, device=device)
        pw = torch.clamp((y == 0).sum() / ((y == 1).sum() + 1e-6), max=10.0)
        loss_fn = nn.BCEWithLogitsLoss(pos_weight=pw)
        opt = torch.optim.Adam(self._node_clf.parameters(), lr=1e-3)

        self._node_clf.train()
        for _ in range(num_epochs):
            logits = self._node_clf(node_emb).squeeze(-1)
            loss = loss_fn(logits, y)
            opt.zero_grad()
            loss.backward()
            opt.step()
        self._node_clf.eval()