import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader

from src.gnn.edge_dataset import EdgeDataset
from src.gnn.model import EdgeGNN


def train_gnn(graph_data):
    """Train an EdgeGNN edge classifier on the supplied graph.

    Args:
        graph_data: mapping with keys "x" (node features), "edge_index",
            "edge_attr", "y" (per-edge float labels), and "train_mask"
            (boolean mask over edges; may be a pandas Series, in which
            case its underlying ndarray is used).

    Returns:
        The trained EdgeGNN model (on CUDA if available, else CPU).
    """
    dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def to_dev(key, dtype):
        # Build a tensor from the raw array in graph_data, then move it
        # onto the chosen device.
        return torch.tensor(graph_data[key], dtype=dtype).to(dev)

    node_feats = to_dev("x", torch.float32)
    edge_index = to_dev("edge_index", torch.long)
    edge_attr = to_dev("edge_attr", torch.float32)
    labels_all = to_dev("y", torch.float32)

    def zscore(t):
        # Column-wise standardization; the epsilon guards constant columns.
        return (t - t.mean(dim=0)) / (t.std(dim=0) + 1e-6)

    # NOTE(review): statistics are computed over ALL nodes/edges, including
    # any held-out ones — potential train/test leakage; confirm intended.
    node_feats = zscore(node_feats)
    edge_attr = zscore(edge_attr)

    mask = graph_data["train_mask"]
    if hasattr(mask, 'values'):  # unwrap a pandas Series to its ndarray
        mask = mask.values
    train_idx = np.where(mask)[0]
    # Message-passing graph restricted to the training edges only.
    train_edge_index = edge_index[:, train_idx]

    loader = DataLoader(
        EdgeDataset(edge_index, edge_attr, labels_all, train_idx),
        batch_size=4096,
        shuffle=True,
    )

    model = EdgeGNN(
        in_channels=node_feats.shape[1],
        hidden_dim=64,
        edge_dim=edge_attr.shape[1],
    ).to(dev)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    # Class-imbalance weighting: ratio of negatives to positives, capped
    # so extremely rare positives cannot blow the loss up.
    raw_pw = (labels_all == 0).sum().float() / (labels_all == 1).sum().float()
    loss_fn = torch.nn.BCEWithLogitsLoss(
        pos_weight=torch.clamp(raw_pw, max=10.0)
    )

    for epoch in range(5):
        total_loss = 0
        for batch in loader:
            src = batch["src"].to(dev)
            dst = batch["dst"].to(dev)
            edge_feat = batch["edge_attr"].to(dev)
            labels = batch["label"].to(dev)
            optimizer.zero_grad()
            # NOTE(review): edge_feat is per-batch while train_edge_index
            # spans all training edges — verify EdgeGNN expects this pairing.
            logits = model(node_feats, train_edge_index, edge_feat, src, dst)
            loss = loss_fn(logits, labels)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            total_loss += loss.item()
        print(f"Epoch {epoch} Loss: {total_loss:.4f}")
    return model