| import numpy as np |
| import torch |
| import torch.nn.functional as F |
| from torch.utils.data import DataLoader |
|
|
| from src.gnn.edge_dataset import EdgeDataset |
| from src.gnn.model import EdgeGNN |
|
|
|
|
def _standardize(t, eps=1e-6):
    """Return ``t`` with every column shifted to zero mean and scaled by its std.

    ``eps`` is added to the per-column standard deviation so constant columns
    become 0 instead of triggering a division by zero.
    """
    return (t - t.mean(dim=0)) / (t.std(dim=0) + eps)


def train_gnn(
    graph_data,
    *,
    epochs=5,
    batch_size=4096,
    hidden_dim=64,
    lr=1e-3,
    max_pos_weight=10.0,
    grad_clip_norm=1.0,
):
    """Train an ``EdgeGNN`` edge classifier on the training edges of a graph.

    Args:
        graph_data: Mapping with keys ``"x"`` (2-D node features), ``"edge_index"``
            (edge endpoints, one column per edge), ``"edge_attr"`` (2-D edge
            features), ``"y"`` (edge labels compared against 0/1 — assumed to be
            aligned with the columns of ``edge_index``; TODO confirm) and
            ``"train_mask"`` (boolean-like mask selecting training edges; objects
            exposing ``.values``, e.g. a pandas Series, are unwrapped first).
        epochs: Number of passes over the training edges.
        batch_size: Number of training edges per mini-batch.
        hidden_dim: Hidden width passed to ``EdgeGNN``.
        lr: Adam learning rate.
        max_pos_weight: Upper bound on the positive-class weight of the loss.
        grad_clip_norm: Maximum total gradient norm used for clipping.

    Returns:
        The trained ``EdgeGNN``, placed on CUDA when available, else CPU.

    Raises:
        ValueError: If ``train_mask`` selects no edges.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    x = torch.tensor(graph_data["x"], dtype=torch.float32).to(device)
    edge_index = torch.tensor(graph_data["edge_index"], dtype=torch.long).to(device)
    edge_attr = torch.tensor(graph_data["edge_attr"], dtype=torch.float32).to(device)
    y = torch.tensor(graph_data["y"], dtype=torch.float32).to(device)

    # NOTE(review): normalization statistics are computed over ALL nodes/edges,
    # including non-training ones. Kept as-is because the inference path
    # presumably recomputes them the same way — confirm before switching to
    # train-only statistics.
    x = _standardize(x)
    edge_attr = _standardize(edge_attr)

    train_mask = graph_data["train_mask"]
    if hasattr(train_mask, "values"):
        train_mask = train_mask.values  # e.g. pandas Series -> ndarray
    train_idx = np.where(train_mask)[0]
    if train_idx.size == 0:
        # Without this check the loop below runs over an empty loader and an
        # untrained model is returned silently.
        raise ValueError("train_mask selects no edges; nothing to train on")
    train_idx_t = torch.as_tensor(train_idx, dtype=torch.long, device=device)

    # Message passing only uses training edges, so held-out edges do not leak
    # into the node embeddings during training.
    train_edge_index = edge_index[:, train_idx_t]

    dataset = EdgeDataset(edge_index, edge_attr, y, train_idx)
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    model = EdgeGNN(
        in_channels=x.shape[1],
        hidden_dim=hidden_dim,
        edge_dim=edge_attr.shape[1],
    ).to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    # Class-balance weight from the TRAINING labels only: the loss never sees
    # other edges, and using them would leak held-out label statistics.
    y_train = y[train_idx_t]
    num_pos = (y_train == 1).sum()
    num_neg = (y_train == 0).sum()
    if num_pos == 0 or num_neg == 0:
        # Degenerate case: a ratio of 0 would zero out the positive term
        # (no learning), and 0/0 would give NaN. Fall back to an unweighted loss.
        pos_weight = torch.tensor(1.0, device=device)
    else:
        pos_weight = torch.clamp(num_neg.float() / num_pos.float(), max=max_pos_weight)
    loss_fn = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)

    model.train()
    for epoch in range(epochs):
        total_loss = 0.0

        for batch in loader:
            src = batch["src"].to(device)
            dst = batch["dst"].to(device)
            edge_feat = batch["edge_attr"].to(device)
            labels = batch["label"].to(device)

            optimizer.zero_grad()

            # NOTE(review): edge_feat holds the features of the *batch* edges,
            # while train_edge_index covers all training edges. Confirm that
            # EdgeGNN expects edge attributes aligned with src/dst rather than
            # with the message-passing edges.
            logits = model(x, train_edge_index, edge_feat, src, dst)

            loss = loss_fn(logits, labels)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip_norm)
            optimizer.step()

            total_loss += loss.item()

        # Reports the SUM of per-batch mean losses for the epoch.
        print(f"Epoch {epoch} Loss: {total_loss:.4f}")

    return model