File size: 2,145 Bytes
a3682cf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader

from src.gnn.edge_dataset import EdgeDataset
from src.gnn.model import EdgeGNN


def train_gnn(graph_data):
    """Train an :class:`EdgeGNN` binary edge classifier on one graph.

    Parameters
    ----------
    graph_data : dict
        Expects keys ``"x"`` (node features), ``"edge_index"`` (2 x E
        connectivity), ``"edge_attr"`` (per-edge features), ``"y"``
        (binary edge labels), and ``"train_mask"`` (boolean mask over
        edges; a pandas Series is also accepted — presumably aligned
        with the edge order, TODO confirm against caller).

    Returns
    -------
    EdgeGNN
        The trained model (left on the selected device).
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    x = torch.tensor(graph_data["x"], dtype=torch.float32).to(device)
    edge_index = torch.tensor(graph_data["edge_index"], dtype=torch.long).to(device)
    edge_attr = torch.tensor(graph_data["edge_attr"], dtype=torch.float32).to(device)
    y = torch.tensor(graph_data["y"], dtype=torch.float32).to(device)

    # Standardize node and edge features; the epsilon guards
    # constant (zero-variance) columns against division by zero.
    x = (x - x.mean(dim=0)) / (x.std(dim=0) + 1e-6)
    edge_attr = (edge_attr - edge_attr.mean(dim=0)) / (edge_attr.std(dim=0) + 1e-6)

    train_mask = graph_data["train_mask"]
    if hasattr(train_mask, 'values'):
        # Accept a pandas Series as well as a plain numpy array.
        train_mask = train_mask.values
    train_idx = np.where(train_mask)[0]

    # Message passing sees only training edges, so representations
    # cannot leak information about held-out edges.
    train_edge_index = edge_index[:, train_idx]

    dataset = EdgeDataset(edge_index, edge_attr, y, train_idx)
    loader = DataLoader(dataset, batch_size=4096, shuffle=True)

    model = EdgeGNN(
        in_channels=x.shape[1],
        hidden_dim=64,
        edge_dim=edge_attr.shape[1],
    ).to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    # Class-imbalance weight computed on TRAIN labels only — the
    # previous version used all of `y`, leaking the held-out class
    # distribution into the loss. Capped at 10x to keep gradients
    # stable under extreme imbalance; the clamp on the positive count
    # avoids inf when the training split has no positive edges.
    y_train = y[train_idx]
    n_pos = (y_train == 1).sum().float()
    n_neg = (y_train == 0).sum().float()
    raw_pw = n_neg / n_pos.clamp(min=1.0)
    pos_weight = torch.clamp(raw_pw, max=10.0)
    loss_fn = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)

    model.train()  # make dropout/batchnorm behavior explicit
    for epoch in range(5):
        total_loss = 0.0

        for batch in loader:
            src = batch["src"].to(device)
            dst = batch["dst"].to(device)
            edge_feat = batch["edge_attr"].to(device)
            labels = batch["label"].to(device)

            optimizer.zero_grad()

            # Model scores the (src, dst) pairs using node embeddings
            # propagated over the training-edge graph only.
            logits = model(x, train_edge_index, edge_feat, src, dst)

            loss = loss_fn(logits, labels)
            loss.backward()
            # Clip gradients to stabilize training on large batches.
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

            total_loss += loss.item()

        print(f"Epoch {epoch} Loss: {total_loss:.4f}")

    return model