# Author: temporal-twins-anon
# Commit: a3682cf (verified) - "Add anonymous Temporal Twins code release"
import torch
import numpy as np
from tqdm import tqdm
from src.tgn.memory import Memory
from src.tgn.model import TGN
from src.tgn.time_encoding import TimeEncoding
def _to_bool_mask(mask):
    """Return ``mask`` as a CPU ``torch.bool`` tensor.

    Accepts a NumPy array, a pandas Series, a plain sequence, or a torch
    tensor.
    """
    if isinstance(mask, torch.Tensor):
        return mask.detach().to(device="cpu", dtype=torch.bool)
    # np.asarray handles ndarrays, pandas Series and lists uniformly.
    return torch.as_tensor(np.asarray(mask), dtype=torch.bool)


def _update_memory(model, memory, u, v, edge_feat, time_enc):
    """Write one batch of interaction messages into ``memory``.

    Memory reads are detached before the messages are computed. Each edge's
    message is delivered to both of its endpoints. A node that appears
    several times in the batch receives the mean of its messages.
    """
    h_u = memory.get(u).detach()
    h_v = memory.get(v).detach()
    msg = model.compute_message(h_u, h_v, edge_feat, time_enc)

    node_ids = torch.cat([u, v])
    messages = torch.cat([msg, msg])
    unique_nodes, inverse_idx = torch.unique(node_ids, return_inverse=True)

    # Sum the messages per node, then divide by the per-node count (a mean).
    agg_msg = torch.zeros_like(memory.memory[unique_nodes])
    agg_msg.index_add_(0, inverse_idx, messages)
    counts = torch.bincount(inverse_idx).unsqueeze(1)
    memory.update(unique_nodes, agg_msg / counts)


def train_tgn(
    graph_data,
    batch_size=1024,
    num_epochs=3,
    *,
    memory_dim=64,
    time_dim=16,
    lr=1e-3,
    time_scale=5.0,
    max_pos_weight=10.0,
):
    """Train a TGN edge classifier on the training edges of ``graph_data``.

    Args:
        graph_data: Mapping with these keys:
            ``"edge_index"``: indexed as ``[0, ids]`` for sources and
            ``[1, ids]`` for destinations.
            ``"edge_attr"``: 2-D per-edge features.
            ``"y"``: per-edge 0/1 labels.
            ``"x"``: node features, one row per node.
            ``"timestamps"``: per-edge times.
            ``"train_mask"`` and optional ``"inference_mask"``: per-edge
            boolean masks.
        batch_size: Number of edges per step.
        num_epochs: Passes over the training edges. The node memory is reset
            at the start of every epoch.
        memory_dim: Size of the per-node memory (passed to both TGN and
            Memory).
        time_dim: Output size of the time encoder.
        lr: Adam learning rate.
        time_scale: Multiplier applied to the [0, 1]-normalized timestamps
            before time encoding.
        max_pos_weight: Upper bound on the positive-class weight of the loss.

    Returns:
        ``(model, memory, time_encoder, norm_stats)``.

        ``memory`` holds the state after the last epoch. When
        ``"inference_mask"`` is present, it is further advanced over the
        inference edges, and ``model`` is left in eval mode.

        ``norm_stats`` holds the edge-feature mean/std, the timestamp min/max
        and the normalized node features.
    """
    device = torch.device("cpu")

    edge_index = torch.tensor(graph_data["edge_index"], dtype=torch.long)
    edge_attr = torch.tensor(graph_data["edge_attr"], dtype=torch.float32)
    labels = torch.tensor(graph_data["y"], dtype=torch.float32)
    x = torch.tensor(graph_data["x"], dtype=torch.float32).to(device)

    # NOTE(review): node, edge and time statistics are computed over ALL
    # nodes/edges, not just the training split. This is kept so that
    # norm_stats matches what downstream code expects; confirm it is intended.
    x = (x - x.mean(dim=0)) / (x.std(dim=0) + 1e-6)

    ea_mean = edge_attr.mean(dim=0)
    ea_std = edge_attr.std(dim=0) + 1e-6
    edge_attr = (edge_attr - ea_mean) / ea_std

    # Timestamps min-max normalized to [0, 1] for the time encoder.
    timestamps_raw = torch.tensor(graph_data["timestamps"], dtype=torch.float32)
    t_min = timestamps_raw.min()
    t_max = timestamps_raw.max()
    timestamps = (timestamps_raw - t_min) / (t_max - t_min + 1e-6)

    num_nodes = x.shape[0]
    model = TGN(
        memory_dim=memory_dim,
        node_dim=x.shape[1],
        edge_dim=edge_attr.shape[1],
        time_dim=time_dim,
    ).to(device)
    time_encoder = TimeEncoding(time_dim).to(device)

    # The time encoder participates in the loss, so its parameters (if any)
    # must be optimized and clipped together with the model's.
    params = list(model.parameters()) + list(time_encoder.parameters())
    optimizer = torch.optim.Adam(params, lr=lr)

    # Train on the train split ONLY.
    train_idx = torch.where(_to_bool_mask(graph_data["train_mask"]))[0]
    num_train = len(train_idx)

    # Class-balance weight from the TRAINING labels only. Using all labels
    # would leak the label distribution of held-out edges. The denominator is
    # clamped so that a split with no positives does not divide by zero.
    train_labels = labels[train_idx]
    num_pos = (train_labels == 1).sum().float()
    num_neg = (train_labels == 0).sum().float()
    pos_weight = torch.clamp(num_neg / num_pos.clamp(min=1.0), max=max_pos_weight)
    loss_fn = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)

    # Created up front so `memory` is always bound, even when num_epochs == 0.
    memory = Memory(num_nodes, memory_dim=memory_dim, device=device)
    for epoch in range(num_epochs):
        if epoch > 0:
            # Start every epoch from a fresh memory.
            memory = Memory(num_nodes, memory_dim=memory_dim, device=device)
        total_loss = 0.0
        # Edges are visited in index order (presumably chronological; confirm).
        for i in tqdm(range(0, num_train, batch_size)):
            batch_ids = train_idx[i:i + batch_size]
            u = edge_index[0, batch_ids].to(device)
            v = edge_index[1, batch_ids].to(device)
            edge_feat = edge_attr[batch_ids].to(device)
            # Amplify time differences to force causality.
            t = timestamps[batch_ids].to(device) * time_scale
            labels_batch = labels[batch_ids].to(device)

            h_u = memory.get(u)
            h_v = memory.get(v)
            time_enc = time_encoder(t)
            logits = model.predict(h_u, h_v, edge_feat, x[u], x[v], time_enc)
            # Clamp logits to keep the loss bounded.
            logits = torch.clamp(logits, -10, 10)
            loss = loss_fn(logits, labels_batch)

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(params, 1.0)
            optimizer.step()
            total_loss += loss.item()

            # The memory is updated only after the prediction, so each batch
            # is scored with the memory state that precedes it.
            _update_memory(model, memory, u, v, edge_feat, time_enc)
        print(f"Epoch {epoch} Loss: {total_loss:.4f}")

    # Build memory for the inference edges (zero-shot test distribution).
    if "inference_mask" in graph_data:
        inf_idx = torch.where(_to_bool_mask(graph_data["inference_mask"]))[0]
        model.eval()
        with torch.no_grad():
            for i in range(0, len(inf_idx), batch_size):
                batch_ids = inf_idx[i:i + batch_size]
                u = edge_index[0, batch_ids].to(device)
                v = edge_index[1, batch_ids].to(device)
                edge_feat = edge_attr[batch_ids].to(device)
                t = timestamps[batch_ids].to(device) * time_scale
                time_enc = time_encoder(t)
                _update_memory(model, memory, u, v, edge_feat, time_enc)

    norm_stats = {
        "ea_mean": ea_mean, "ea_std": ea_std,
        "t_min": t_min, "t_max": t_max,
        "x": x,
    }
    return model, memory, time_encoder, norm_stats