| """ |
| Graph Isomorphism Network (GIN) training on MUTAG dataset. |
| Demonstrates isomorphism-aware learning for graph classification. |
| """ |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from torch_geometric.nn import GINConv, global_add_pool |
| from torch_geometric.loader import DataLoader |
| from datasets import load_dataset |
| import numpy as np |
| from sklearn.model_selection import StratifiedKFold |
| from sklearn.metrics import accuracy_score |
| import random |
|
|
| |
| torch.manual_seed(42) |
| np.random.seed(42) |
| random.seed(42) |
|
|
|
|
| def load_mutag_from_hf(): |
| """Load MUTAG dataset from HuggingFace and convert to PyG Data objects.""" |
| ds = load_dataset("graphs-datasets/MUTAG", split="train") |
| |
| from torch_geometric.data import Data |
| |
| data_list = [] |
| for row in ds: |
| edge_index = torch.tensor(row["edge_index"], dtype=torch.long) |
| x = torch.tensor(row["node_feat"], dtype=torch.float) |
| y = torch.tensor(row["y"], dtype=torch.long) |
| num_nodes = row["num_nodes"] |
| |
| |
| edge_attr = None |
| if row.get("edge_attr") is not None: |
| edge_attr = torch.tensor(row["edge_attr"], dtype=torch.float) |
| |
| data = Data(x=x, edge_index=edge_index, y=y, edge_attr=edge_attr) |
| data.num_nodes = num_nodes |
| data_list.append(data) |
| |
| return data_list |
|
|
|
|
| class GIN(nn.Module): |
| """ |
| Graph Isomorphism Network (GIN) implementation. |
| Key insight: SUM aggregation (not mean/max) is injective for multisets, |
| making GIN as expressive as the 1-WL test for graph isomorphism. |
| """ |
| def __init__(self, in_channels, hidden_channels, num_classes, num_layers=5, dropout=0.5): |
| super().__init__() |
| self.num_layers = num_layers |
| self.dropout = dropout |
| |
| self.convs = nn.ModuleList() |
| self.batch_norms = nn.ModuleList() |
| |
| |
| |
| |
| for i in range(num_layers): |
| in_dim = in_channels if i == 0 else hidden_channels |
| mlp = nn.Sequential( |
| nn.Linear(in_dim, hidden_channels), |
| nn.ReLU(), |
| nn.Linear(hidden_channels, hidden_channels) |
| ) |
| self.convs.append(GINConv(mlp, train_eps=True)) |
| self.batch_norms.append(nn.BatchNorm1d(hidden_channels)) |
| |
| |
| |
| self.fc1 = nn.Linear(hidden_channels, hidden_channels) |
| self.fc2 = nn.Linear(hidden_channels, num_classes) |
| |
| def forward(self, x, edge_index, batch): |
| |
| for i, (conv, bn) in enumerate(zip(self.convs, self.batch_norms)): |
| x = conv(x, edge_index) |
| x = bn(x) |
| x = F.relu(x) |
| x = F.dropout(x, p=self.dropout, training=self.training) |
| |
| |
| x = global_add_pool(x, batch) |
| |
| |
| x = F.relu(self.fc1(x)) |
| x = F.dropout(x, p=self.dropout, training=self.training) |
| x = self.fc2(x) |
| return x |
|
|
|
|
| def train_epoch(model, loader, optimizer, device): |
| model.train() |
| total_loss = 0 |
| for data in loader: |
| data = data.to(device) |
| optimizer.zero_grad() |
| out = model(data.x, data.edge_index, data.batch) |
| loss = F.cross_entropy(out, data.y) |
| loss.backward() |
| optimizer.step() |
| total_loss += loss.item() * data.num_graphs |
| return total_loss / len(loader.dataset) |
|
|
|
|
| @torch.no_grad() |
| def evaluate(model, loader, device): |
| model.eval() |
| preds, labels = [], [] |
| for data in loader: |
| data = data.to(device) |
| out = model(data.x, data.edge_index, data.batch) |
| pred = out.argmax(dim=1) |
| preds.extend(pred.cpu().numpy()) |
| labels.extend(data.y.cpu().numpy()) |
| return accuracy_score(labels, preds) |
|
|
|
|
| def cross_validate_gin(data_list, num_classes, num_folds=10, num_epochs=200): |
| """10-fold stratified cross-validation as in original GIN paper.""" |
| labels = [d.y.item() for d in data_list] |
| skf = StratifiedKFold(n_splits=num_folds, shuffle=True, random_state=42) |
| |
| device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
| print(f"Device: {device}") |
| |
| fold_accs = [] |
| |
| for fold, (train_idx, test_idx) in enumerate(skf.split(data_list, labels)): |
| print(f"\n=== Fold {fold + 1}/{num_folds} ===") |
| |
| train_data = [data_list[i] for i in train_idx] |
| test_data = [data_list[i] for i in test_idx] |
| |
| train_loader = DataLoader(train_data, batch_size=32, shuffle=True) |
| test_loader = DataLoader(test_data, batch_size=32) |
| |
| in_channels = train_data[0].x.shape[1] |
| model = GIN(in_channels, hidden_channels=64, num_classes=num_classes).to(device) |
| optimizer = torch.optim.Adam(model.parameters(), lr=0.01) |
| |
| best_test_acc = 0 |
| for epoch in range(num_epochs): |
| loss = train_epoch(model, train_loader, optimizer, device) |
| if (epoch + 1) % 50 == 0: |
| train_acc = evaluate(model, train_loader, device) |
| test_acc = evaluate(model, test_loader, device) |
| print(f" Epoch {epoch+1}: loss={loss:.4f}, train_acc={train_acc:.4f}, test_acc={test_acc:.4f}") |
| if test_acc > best_test_acc: |
| best_test_acc = test_acc |
| |
| |
| final_test_acc = evaluate(model, test_loader, device) |
| fold_accs.append(final_test_acc) |
| print(f" Fold {fold+1} best test accuracy: {best_test_acc:.4f}, final: {final_test_acc:.4f}") |
| |
| print(f"\n=== Results ===") |
| print(f"Mean accuracy: {np.mean(fold_accs)*100:.2f}% ± {np.std(fold_accs)*100:.2f}%") |
| print(f"All fold accuracies: {[f'{a*100:.1f}%' for a in fold_accs]}") |
| return fold_accs |
|
|
|
|
| def main(): |
| print("=" * 60) |
| print("Graph Isomorphism Network (GIN) - MUTAG Classification") |
| print("=" * 60) |
| |
| |
| print("\n[1] Loading MUTAG dataset from HuggingFace...") |
| data_list = load_mutag_from_hf() |
| num_classes = len(set(d.y.item() for d in data_list)) |
| print(f" Loaded {len(data_list)} graphs, {num_classes} classes") |
| print(f" Node feature dim: {data_list[0].x.shape[1]}") |
| print(f" Average nodes per graph: {np.mean([d.num_nodes for d in data_list]):.1f}") |
| |
| |
| print("\n[2] Training GIN with 5-fold cross-validation...") |
| fold_accs = cross_validate_gin(data_list, num_classes, num_folds=5, num_epochs=150) |
| |
| print("\n" + "=" * 60) |
| print("KEY TAKEAWAYS - Why Isomorphism Matters for AI") |
| print("=" * 60) |
| print(""" |
| 1. GIN uses SUM aggregation (not mean/max) — the ONLY injective |
| multiset function. This makes it AS EXPRESSIVE as the 1-WL |
| (Weisfeiler-Lehman) graph isomorphism test. |
| |
| 2. Traditional GCN/GAT use mean/max pooling, which CANNOT distinguish |
| certain graph structures. GIN can. |
| |
| 3. This expressiveness is proven by theory and practice: |
| - GIN achieves SOTA on many graph classification benchmarks |
| - It generalizes better to unseen graph structures |
| - It learns true structural representations, not just node features |
| |
| 4. Applications: |
| - Molecular property prediction (drug discovery) |
| - Social network analysis |
| - Knowledge graph reasoning |
| - Program analysis & code similarity |
| - Anomaly detection in transaction graphs |
| |
| 5. The epsilon (ε) parameter in GIN controls how much the central |
| node's own features contribute — learnable or fixed (GIN-0). |
| """) |
|
|
|
|
| if __name__ == "__main__": |
| main() |
|
|