# Anonymous "Temporal Twins" code release (temporal-twins-anon, commit a3682cf, verified).
import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from src.gnn.edge_dataset import EdgeDataset
from src.gnn.model import EdgeGNN
def train_gnn(graph_data):
    """Train an EdgeGNN binary edge classifier on the edges in ``train_mask``.

    Args:
        graph_data: mapping with array-like entries:
            ``x`` — node feature matrix,
            ``edge_index`` — 2 x E integer edge list,
            ``edge_attr`` — E x F edge feature matrix,
            ``y`` — per-edge binary labels (0/1),
            ``train_mask`` — boolean mask over edges (may be a pandas Series).

    Returns:
        The trained ``EdgeGNN`` model, resident on the selected device.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    x = torch.tensor(graph_data["x"], dtype=torch.float32).to(device)
    edge_index = torch.tensor(graph_data["edge_index"], dtype=torch.long).to(device)
    edge_attr = torch.tensor(graph_data["edge_attr"], dtype=torch.float32).to(device)
    y = torch.tensor(graph_data["y"], dtype=torch.float32).to(device)

    # Standardize features; the epsilon guards constant (zero-std) columns.
    # NOTE(review): stats are computed over ALL nodes/edges, including any
    # held-out edges — common in transductive setups, but confirm this is
    # intentional for edge_attr.
    x = (x - x.mean(dim=0)) / (x.std(dim=0) + 1e-6)
    edge_attr = (edge_attr - edge_attr.mean(dim=0)) / (edge_attr.std(dim=0) + 1e-6)

    train_mask = graph_data["train_mask"]
    if hasattr(train_mask, 'values'):  # unwrap a pandas Series to its ndarray
        train_mask = train_mask.values
    train_idx = np.where(train_mask)[0]
    # Restrict message passing to training edges so held-out edges never
    # contribute structure to the model.
    train_edge_index = edge_index[:, train_idx]

    dataset = EdgeDataset(edge_index, edge_attr, y, train_idx)
    loader = DataLoader(dataset, batch_size=4096, shuffle=True)

    model = EdgeGNN(
        in_channels=x.shape[1],
        hidden_dim=64,
        edge_dim=edge_attr.shape[1],
    ).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    # BUGFIX: estimate pos_weight from TRAINING labels only — the original
    # used all of `y`, leaking held-out label statistics into the loss.
    # The denominator is clamped so an all-negative training split cannot
    # produce inf/NaN; the cap at 10 keeps a rare positive class from
    # dominating the gradient.
    y_train = y[train_idx]
    neg = (y_train == 0).sum().float()
    pos = (y_train == 1).sum().float().clamp(min=1.0)
    pos_weight = torch.clamp(neg / pos, max=10.0)
    loss_fn = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)

    for epoch in range(5):
        total_loss = 0.0
        for batch in loader:
            src = batch["src"].to(device)
            dst = batch["dst"].to(device)
            edge_feat = batch["edge_attr"].to(device)
            labels = batch["label"].to(device)

            optimizer.zero_grad()
            logits = model(x, train_edge_index, edge_feat, src, dst)
            loss = loss_fn(logits, labels)
            loss.backward()
            # Gradient clipping stabilizes early epochs under the weighted loss.
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            total_loss += loss.item()
        print(f"Epoch {epoch} Loss: {total_loss:.4f}")
    return model