# Source: temporal-twins-code / models/tgn_wrapper.py
# Author: temporal-twins-anon
# Commit message: Add anonymous Temporal Twins code release
# Commit: a3682cf (verified)
"""
models/tgn_wrapper.py
=====================
Wraps the existing src/tgn/ pipeline behind the TemporalModel interface.
Architecture (unchanged from src/tgn/model.py):
- GRU-based memory module
- Message MLP (memory × 2 + edge + time → memory)
- Node classifier head: memory + static_feat → fraud prob
  (this wrapper feeds zeros in place of static_feat)
"""
from __future__ import annotations
import copy
from typing import List
import numpy as np
import pandas as pd
import torch
from models.base import TemporalModel
from src.graph.dataset_builder import build_graph_dataset
from src.graph.graph_builder import build_edge_features
from src.tgn.memory import Memory
from src.tgn.model import TGN
from src.tgn.time_encoding import TimeEncoding
from src.tgn.train import train_tgn
class TGNWrapper(TemporalModel):
    """TGN with GRU memory, wrapped behind the unified TemporalModel interface.

    Lifecycle:
      * ``fit`` trains the underlying TGN via ``train_tgn``.
      * ``predict`` replays evaluation events through the *trained* memory
        (mutating it) and scores nodes with the model's classifier head.
      * ``extract_prefix_embeddings`` replays events through a *fresh* memory
        and snapshots sender states at requested event prefixes.
      * ``train_node_classifier`` fine-tunes only the classifier head.
    """

    def __init__(
        self,
        memory_dim: int = 64,
        time_dim: int = 16,
        hidden_dim: int = 128,
        device: str = "cpu",
    ):
        # NOTE(review): memory_dim / time_dim / hidden_dim are stored but not
        # forwarded to train_tgn() in fit(); confirm train_tgn's own defaults
        # agree with them.
        self.memory_dim = memory_dim
        self.time_dim = time_dim
        self.hidden_dim = hidden_dim
        self.device = torch.device(device)
        # filled by fit()
        self._model: TGN | None = None
        self._memory: Memory | None = None
        self._time_encoder: TimeEncoding | None = None
        self._norm_stats: dict | None = None
        self._num_nodes: int = 0
        self._users: pd.DataFrame | None = None
        # True once _train_node_head has fine-tuned the current model's head.
        self._node_head_fitted = False

    @property
    def name(self) -> str:
        """Display name of this model."""
        return "TGN"

    # ------------------------------------------------------------------ #
    def fit(self, df_train: pd.DataFrame, num_epochs: int = 3) -> None:
        """Train the TGN on every event of ``df_train``.

        Args:
            df_train: event table with at least ``timestamp``, ``sender_id``
                and ``receiver_id`` (plus whatever ``build_graph_dataset``
                reads).
            num_epochs: forwarded to ``train_tgn``.
        """
        df_train = df_train.sort_values("timestamp").reset_index(drop=True)
        # build_graph_dataset expects a users DataFrame; derive a minimal one
        users = _make_users_df(df_train)
        self._users = users
        graph_data = build_graph_dataset(df_train, users)
        # Override train_mask to use ALL training events
        graph_data["train_mask"] = np.ones(len(df_train), dtype=bool)
        self._model, self._memory, self._time_encoder, self._norm_stats = train_tgn(
            graph_data, num_epochs=num_epochs
        )
        self._num_nodes = self._memory.memory.shape[0]
        # The model just returned by train_tgn has not been through
        # _train_node_head, so any earlier fine-tuning no longer applies.
        self._node_head_fitted = False

    # ------------------------------------------------------------------ #
    def _memory_width(self) -> int:
        """Return the feature width of the trained memory (``memory_dim`` before fit)."""
        if self._memory is None:
            return self.memory_dim
        return int(self._memory.memory.shape[1])

    def _node_embeddings(self, nodes: List[int]) -> torch.Tensor:
        """Return detached memory rows for ``nodes``, one row per entry.

        Ids the trained memory does not cover (negative or >= its row count)
        get an all-zero row, i.e. the state ``reset_memory()`` produces,
        rather than some other node's memory.
        """
        mem = self._memory.memory
        ids = torch.as_tensor(
            np.asarray(nodes, dtype=np.int64), dtype=torch.long, device=self.device
        )
        in_range = (ids >= 0) & (ids < mem.shape[0])
        emb = torch.zeros(
            (len(ids), mem.shape[1]), dtype=mem.dtype, device=self.device
        )
        emb[in_range] = mem[ids[in_range]].detach().to(self.device)
        return emb

    # ------------------------------------------------------------------ #
    def predict(self, df_eval: pd.DataFrame, eval_nodes: List[int]) -> np.ndarray:
        """Replay ``df_eval`` through the trained memory, then score ``eval_nodes``.

        The warm-up reads only event columns (no labels). It updates the
        wrapper's persistent memory in place, so repeated calls accumulate
        state; call ``reset_memory()`` to clear it.

        Args:
            df_eval: evaluation events (``timestamp``, ``sender_id``,
                ``receiver_id`` plus columns read by ``build_edge_features``).
            eval_nodes: node ids to score. Ids outside the trained memory are
                scored from an all-zero memory row.

        Returns:
            float32 array of probabilities, one per entry of ``eval_nodes``.
        """
        assert self._model is not None, "Call fit() first."
        df_eval = df_eval.sort_values("timestamp").reset_index(drop=True)
        device = self.device
        model = self._model
        memory = self._memory
        time_encoder = self._time_encoder
        ns = self._norm_stats

        # Warm-up: pass eval events through memory (no label access).
        # NOTE(review): event ids >= self._num_nodes are not handled here;
        # memory.get() presumably fails on them -- confirm or grow memory.
        edge_index = torch.tensor(
            np.vstack([df_eval["sender_id"].values, df_eval["receiver_id"].values]),
            dtype=torch.long,
        )
        edge_attr = torch.tensor(build_edge_features(df_eval), dtype=torch.float32)
        edge_attr = (edge_attr - ns["ea_mean"]) / ns["ea_std"]
        timestamps = torch.tensor(df_eval["timestamp"].values, dtype=torch.float32)
        timestamps = (timestamps - ns["t_min"]) / (ns["t_max"] - ns["t_min"] + 1e-6)

        batch_size = 1024
        model.eval()
        with torch.no_grad():
            for start in range(0, len(df_eval), batch_size):
                # A slice clips at the end exactly like range(i, min(i+bs, n)).
                batch = slice(start, start + batch_size)
                u = edge_index[0, batch].to(device)
                v = edge_index[1, batch].to(device)
                ef = edge_attr[batch].to(device)
                # Same 5x time scaling as extract_prefix_embeddings.
                t = timestamps[batch].to(device) * 5.0
                time_enc = time_encoder(t)
                h_u = memory.get(u)
                h_v = memory.get(v)
                msg = model.compute_message(h_u, h_v, ef, time_enc)
                # Both endpoints receive the same message; nodes hit several
                # times in a batch get the mean of their messages.
                node_ids = torch.cat([u, v])
                messages = torch.cat([msg, msg])
                unique_nodes, inv = torch.unique(node_ids, return_inverse=True)
                agg = torch.zeros_like(memory.memory[unique_nodes])
                agg.index_add_(0, inv, messages)
                counts = torch.bincount(inv).unsqueeze(1).to(agg.dtype)
                memory.update(unique_nodes, agg / counts)

            # Score eval nodes; static features are replaced by zeros.
            node_emb = self._node_embeddings(eval_nodes)
            x_zeros = torch.zeros(len(node_emb), ns["x"].shape[1], device=device)
            combined = torch.cat([node_emb, x_zeros], dim=1)
            logits = model.node_classifier(combined).squeeze(-1)
            probs = torch.sigmoid(logits).cpu().numpy()
        return probs.astype(np.float32)

    def extract_prefix_embeddings(
        self,
        df_eval: pd.DataFrame,
        examples: pd.DataFrame,
    ) -> np.ndarray:
        """Snapshot sender memory right after specific events of ``df_eval``.

        Events are replayed one at a time, in timestamp order, through a
        *fresh* ``Memory``; the trained memory is left untouched. After each
        event, if ``(sender_id, local_event_idx)`` equals some example's
        ``(sender_id, eval_local_event_idx)``, the sender's memory row is
        copied into that example's output row.

        Args:
            df_eval: evaluation events. ``local_event_idx`` (0-based per-sender
                event counter after sorting) is derived if missing.
            examples: rows with ``sender_id`` and ``eval_local_event_idx``.

        Returns:
            float32 array of shape ``(len(examples), memory_width)``; rows whose
            key never occurs in ``df_eval`` stay zero.
        """
        assert self._model is not None, "Call fit() first."
        # Use the trained memory's width: memory_dim is not passed to train_tgn.
        mem_dim = self._memory_width()
        if examples.empty:
            return np.zeros((0, mem_dim), dtype=np.float32)
        df_eval = df_eval.sort_values("timestamp").reset_index(drop=True).copy()
        if "local_event_idx" not in df_eval.columns:
            df_eval["local_event_idx"] = df_eval.groupby("sender_id").cumcount().astype(np.int32)

        # (sender_id, local event index) -> example rows to fill at that event.
        capture_map: dict[tuple[int, int], list[int]] = {}
        for ex_idx, row in enumerate(examples.itertuples(index=False)):
            key = (int(row.sender_id), int(row.eval_local_event_idx))
            capture_map.setdefault(key, []).append(ex_idx)

        # Size the fresh memory to cover ids unseen during training as well.
        needed_rows = int(max(df_eval["sender_id"].max(), df_eval["receiver_id"].max())) + 1
        num_nodes = max(self._num_nodes, needed_rows)
        device = self.device
        model = self._model
        time_encoder = self._time_encoder
        ns = self._norm_stats
        memory = Memory(num_nodes, memory_dim=mem_dim, device=device)

        # Normalisation stats may be tensors or array-likes; work in numpy.
        ea_mean = ns["ea_mean"].detach().cpu().numpy() if isinstance(ns["ea_mean"], torch.Tensor) else np.asarray(ns["ea_mean"], dtype=np.float32)
        ea_std = ns["ea_std"].detach().cpu().numpy() if isinstance(ns["ea_std"], torch.Tensor) else np.asarray(ns["ea_std"], dtype=np.float32)
        t_min = float(ns["t_min"].item()) if isinstance(ns["t_min"], torch.Tensor) else float(ns["t_min"])
        t_max = float(ns["t_max"].item()) if isinstance(ns["t_max"], torch.Tensor) else float(ns["t_max"])
        edge_attr = build_edge_features(df_eval).astype(np.float32)
        edge_attr = (edge_attr - ea_mean) / ea_std
        timestamps = df_eval["timestamp"].to_numpy(dtype=np.float32)
        timestamps = (timestamps - t_min) / (t_max - t_min + 1e-6)
        timestamps = timestamps * 5.0

        out = np.zeros((len(examples), mem_dim), dtype=np.float32)
        model.eval()
        with torch.no_grad():
            for idx, row in enumerate(df_eval.itertuples(index=False)):
                u = torch.tensor([int(row.sender_id)], dtype=torch.long, device=device)
                v = torch.tensor([int(row.receiver_id)], dtype=torch.long, device=device)
                ef = torch.tensor(edge_attr[idx:idx + 1], dtype=torch.float32, device=device)
                t = torch.tensor([timestamps[idx]], dtype=torch.float32, device=device)
                time_enc = time_encoder(t)
                h_u = memory.get(u)
                h_v = memory.get(v)
                msg = model.compute_message(h_u, h_v, ef, time_enc)
                # Same aggregation as predict(); with a self-loop (u == v) the
                # two identical messages are averaged into one update.
                node_ids = torch.cat([u, v])
                messages = torch.cat([msg, msg], dim=0)
                unique_nodes, inverse_idx = torch.unique(node_ids, return_inverse=True)
                agg_msg = torch.zeros((len(unique_nodes), mem_dim), device=device)
                agg_msg.index_add_(0, inverse_idx, messages)
                counts = torch.bincount(inverse_idx).unsqueeze(1).float()
                memory.update(unique_nodes, agg_msg / counts)
                key = (int(row.sender_id), int(row.local_event_idx))
                if key in capture_map:
                    emb = memory.memory[int(row.sender_id)].detach().cpu().numpy().astype(np.float32)
                    for ex_idx in capture_map[key]:
                        out[ex_idx] = emb
        return out

    # ------------------------------------------------------------------ #
    def reset_memory(self) -> None:
        """Zero the trained memory in place (no-op before ``fit``)."""
        if self._memory is not None:
            self._memory.memory.zero_()

    # ------------------------------------------------------------------ #
    def _train_node_head(
        self,
        eval_nodes: List[int],
        y_train: np.ndarray,
        num_epochs: int = 100,
    ) -> None:
        """Fine-tune only the node classifier head on training labels.

        Every other parameter is frozen while training. The original
        ``requires_grad`` flags and the model's train/eval mode are restored
        afterwards, even if training raises.

        Args:
            eval_nodes: node ids whose current memory rows are the inputs; ids
                outside the trained memory use an all-zero row.
            y_train: binary labels aligned with ``eval_nodes``.
            num_epochs: number of full-batch Adam steps.

        Raises:
            AssertionError: if ``fit`` was not called, or the head never
                received a finite gradient.
        """
        assert self._model is not None
        device = self.device
        model = self._model

        # Memory is not updated during head training, so the inputs are
        # constant across epochs: build them once.
        node_emb = self._node_embeddings(eval_nodes)
        x = torch.zeros(len(node_emb), self._norm_stats["x"].shape[1], device=device)
        combined = torch.cat([node_emb, x], dim=1)
        y = torch.tensor(y_train, dtype=torch.float32, device=device)

        head_params = list(model.node_classifier.parameters())
        saved_flags = [(p, p.requires_grad) for p in model.parameters()]
        was_training = model.training
        saw_grad = False
        try:
            for p in model.parameters():
                p.requires_grad = False
            for p in head_params:
                p.requires_grad = True
            opt = torch.optim.Adam(head_params, lr=1e-3)
            # Weight positives by the negative/positive ratio, capped at 10.
            pw = torch.clamp((y == 0).sum() / ((y == 1).sum() + 1e-6), max=10.0)
            loss_fn = torch.nn.BCEWithLogitsLoss(pos_weight=pw)
            model.train()
            for _ in range(num_epochs):
                logits = model.node_classifier(combined).squeeze(-1)
                loss = loss_fn(logits, y)
                opt.zero_grad()
                loss.backward()
                saw_grad = saw_grad or any(
                    p.grad is not None and torch.isfinite(p.grad).all()
                    for p in head_params
                )
                opt.step()
        finally:
            for p, flag in saved_flags:
                p.requires_grad = flag
            model.train(was_training)
        assert saw_grad, "TGN node classifier did not receive gradients."
        self._node_head_fitted = True

    def train_node_classifier(
        self,
        eval_nodes: List[int],
        y_labels: np.ndarray,
        num_epochs: int = 100,
    ) -> None:
        """Public entry point: fine-tune the classifier head (see ``_train_node_head``)."""
        self._train_node_head(eval_nodes, y_labels, num_epochs=num_epochs)
# ------------------------------------------------------------------ #
# Helpers #
# ------------------------------------------------------------------ #
def _make_users_df(df: pd.DataFrame) -> pd.DataFrame:
    """Create a minimal users DataFrame covering every node id in ``df``.

    Ids are taken from both ``sender_id`` and ``receiver_id``. The result has a
    single int64 ``user_id`` column with one row per id in ``0 .. max_id``.
    An empty ``df`` yields an empty frame with the same column.
    """
    if df.empty:
        # max() of an empty column is NaN, which int() cannot convert.
        return pd.DataFrame({"user_id": np.arange(0, dtype=np.int64)})
    max_id = int(max(df["sender_id"].max(), df["receiver_id"].max()))
    return pd.DataFrame({"user_id": np.arange(max_id + 1, dtype=np.int64)})