""" Graph Isomorphism Network (GIN) training on MUTAG dataset. Demonstrates isomorphism-aware learning for graph classification. """ import torch import torch.nn as nn import torch.nn.functional as F from torch_geometric.nn import GINConv, global_add_pool from torch_geometric.loader import DataLoader from datasets import load_dataset import numpy as np from sklearn.model_selection import StratifiedKFold from sklearn.metrics import accuracy_score import random # Set seeds for reproducibility torch.manual_seed(42) np.random.seed(42) random.seed(42) def load_mutag_from_hf(): """Load MUTAG dataset from HuggingFace and convert to PyG Data objects.""" ds = load_dataset("graphs-datasets/MUTAG", split="train") from torch_geometric.data import Data data_list = [] for row in ds: edge_index = torch.tensor(row["edge_index"], dtype=torch.long) x = torch.tensor(row["node_feat"], dtype=torch.float) y = torch.tensor(row["y"], dtype=torch.long) num_nodes = row["num_nodes"] # Handle edge_attr if available edge_attr = None if row.get("edge_attr") is not None: edge_attr = torch.tensor(row["edge_attr"], dtype=torch.float) data = Data(x=x, edge_index=edge_index, y=y, edge_attr=edge_attr) data.num_nodes = num_nodes data_list.append(data) return data_list class GIN(nn.Module): """ Graph Isomorphism Network (GIN) implementation. Key insight: SUM aggregation (not mean/max) is injective for multisets, making GIN as expressive as the 1-WL test for graph isomorphism. """ def __init__(self, in_channels, hidden_channels, num_classes, num_layers=5, dropout=0.5): super().__init__() self.num_layers = num_layers self.dropout = dropout self.convs = nn.ModuleList() self.batch_norms = nn.ModuleList() # Build MLPs for each GINConv layer # GIN-0: train_eps=False (epsilon=0 fixed) # GIN-epsilon: train_eps=True (learnable epsilon) for i in range(num_layers): in_dim = in_channels if i == 0 else hidden_channels mlp = nn.Sequential( nn.Linear(in_dim, hidden_channels), nn.ReLU(), nn.Linear(hidden_channels, hidden_channels) ) self.convs.append(GINConv(mlp, train_eps=True)) self.batch_norms.append(nn.BatchNorm1d(hidden_channels)) # Graph-level readout classifier # Sum pooling is CRITICAL: it's the only injective multiset function self.fc1 = nn.Linear(hidden_channels, hidden_channels) self.fc2 = nn.Linear(hidden_channels, num_classes) def forward(self, x, edge_index, batch): # Node-level GIN layers for i, (conv, bn) in enumerate(zip(self.convs, self.batch_norms)): x = conv(x, edge_index) x = bn(x) x = F.relu(x) x = F.dropout(x, p=self.dropout, training=self.training) # Graph-level readout: SUM aggregation across all nodes in each graph x = global_add_pool(x, batch) # Final classifier x = F.relu(self.fc1(x)) x = F.dropout(x, p=self.dropout, training=self.training) x = self.fc2(x) return x def train_epoch(model, loader, optimizer, device): model.train() total_loss = 0 for data in loader: data = data.to(device) optimizer.zero_grad() out = model(data.x, data.edge_index, data.batch) loss = F.cross_entropy(out, data.y) loss.backward() optimizer.step() total_loss += loss.item() * data.num_graphs return total_loss / len(loader.dataset) @torch.no_grad() def evaluate(model, loader, device): model.eval() preds, labels = [], [] for data in loader: data = data.to(device) out = model(data.x, data.edge_index, data.batch) pred = out.argmax(dim=1) preds.extend(pred.cpu().numpy()) labels.extend(data.y.cpu().numpy()) return accuracy_score(labels, preds) def cross_validate_gin(data_list, num_classes, num_folds=10, num_epochs=200): """10-fold stratified 
def cross_validate_gin(data_list, num_classes, num_folds=10, num_epochs=200):
    """Stratified k-fold cross-validation (the original GIN paper uses 10 folds)."""
    labels = [d.y.item() for d in data_list]
    skf = StratifiedKFold(n_splits=num_folds, shuffle=True, random_state=42)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Device: {device}")

    fold_accs = []
    for fold, (train_idx, test_idx) in enumerate(skf.split(data_list, labels)):
        print(f"\n=== Fold {fold + 1}/{num_folds} ===")
        train_data = [data_list[i] for i in train_idx]
        test_data = [data_list[i] for i in test_idx]

        train_loader = DataLoader(train_data, batch_size=32, shuffle=True)
        test_loader = DataLoader(test_data, batch_size=32)

        in_channels = train_data[0].x.shape[1]
        model = GIN(in_channels, hidden_channels=64,
                    num_classes=num_classes).to(device)
        optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

        best_test_acc = 0
        for epoch in range(num_epochs):
            loss = train_epoch(model, train_loader, optimizer, device)
            if (epoch + 1) % 50 == 0:
                train_acc = evaluate(model, train_loader, device)
                test_acc = evaluate(model, test_loader, device)
                print(f"  Epoch {epoch + 1}: loss={loss:.4f}, "
                      f"train_acc={train_acc:.4f}, test_acc={test_acc:.4f}")
                if test_acc > best_test_acc:
                    best_test_acc = test_acc

        # Final evaluation
        final_test_acc = evaluate(model, test_loader, device)
        fold_accs.append(final_test_acc)
        print(f"  Fold {fold + 1} best test accuracy: {best_test_acc:.4f}, "
              f"final: {final_test_acc:.4f}")

    print("\n=== Results ===")
    print(f"Mean accuracy: {np.mean(fold_accs) * 100:.2f}% "
          f"± {np.std(fold_accs) * 100:.2f}%")
    print(f"All fold accuracies: {[f'{a * 100:.1f}%' for a in fold_accs]}")
    return fold_accs


def main():
    print("=" * 60)
    print("Graph Isomorphism Network (GIN) - MUTAG Classification")
    print("=" * 60)

    # 1. Load MUTAG
    print("\n[1] Loading MUTAG dataset from HuggingFace...")
    data_list = load_mutag_from_hf()
    num_classes = len(set(d.y.item() for d in data_list))
    print(f"    Loaded {len(data_list)} graphs, {num_classes} classes")
    print(f"    Node feature dim: {data_list[0].x.shape[1]}")
    print(f"    Average nodes per graph: "
          f"{np.mean([d.num_nodes for d in data_list]):.1f}")

    # 2. Train with cross-validation
    print("\n[2] Training GIN with 5-fold cross-validation...")
    fold_accs = cross_validate_gin(data_list, num_classes,
                                   num_folds=5, num_epochs=150)

    print("\n" + "=" * 60)
    print("KEY TAKEAWAYS - Why Isomorphism Matters for AI")
    print("=" * 60)
    print("""
1. GIN uses SUM aggregation (not mean/max). Of the three, only sum can
   represent injective functions over multisets, which makes GIN as
   expressive as the 1-WL (Weisfeiler-Lehman) graph isomorphism test.
2. GCN and GAT use mean-style aggregation, which cannot distinguish
   neighborhoods that differ only in element multiplicity. GIN can.
3. This expressiveness shows up in both theory and practice:
   - GIN achieved state-of-the-art results on many graph classification
     benchmarks when introduced
   - It can generalize to graph structures unseen during training
   - It learns structural representations, not just node features
4. Applications:
   - Molecular property prediction (drug discovery)
   - Social network analysis
   - Knowledge graph reasoning
   - Program analysis & code similarity
   - Anomaly detection in transaction graphs
5. The epsilon (ε) parameter in GIN controls how much the central node's
   own features contribute: learnable (GIN-ε) or fixed at zero (GIN-0).
""")


if __name__ == "__main__":
    main()
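

# ---------------------------------------------------------------------------
# Appendix sketch (illustrative; never called by main): what a single GINConv
# layer computes. The GIN update rule is
#
#     h_v' = MLP((1 + eps) * h_v + sum_{u in N(v)} h_u)
#
# The check below reproduces one layer by hand on a 3-node path graph 0-1-2
# and compares it against PyG's GINConv; conv.eps is the epsilon above
# (a learnable scalar initialized to 0 when train_eps=True).
# ---------------------------------------------------------------------------
def _gin_update_sketch():
    mlp = nn.Sequential(nn.Linear(2, 4), nn.ReLU(), nn.Linear(4, 4))
    conv = GINConv(mlp, train_eps=True)
    x = torch.randn(3, 2)
    # Undirected path 0-1-2, stored as directed edges in both directions
    edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])
    # Neighbor sums: node 0 <- {1}, node 1 <- {0, 2}, node 2 <- {1}
    neighbor_sum = torch.stack([x[1], x[0] + x[2], x[1]])
    manual = mlp((1 + conv.eps) * x + neighbor_sum)
    assert torch.allclose(conv(x, edge_index), manual, atol=1e-6)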