"""
Graph Isomorphism Network (GIN) training on MUTAG dataset.
Demonstrates isomorphism-aware learning for graph classification.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GINConv, global_add_pool
from torch_geometric.loader import DataLoader
from torch_geometric.data import Data
from datasets import load_dataset
import numpy as np
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import accuracy_score
import random
# Set seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)
random.seed(42)
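# Also seed CUDA RNGs for GPU runs (this call is a no-op on CPU-only machines).
torch.cuda.manual_seed_all(42)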
def load_mutag_from_hf():
    """Load MUTAG dataset from HuggingFace and convert to PyG Data objects."""
    ds = load_dataset("graphs-datasets/MUTAG", split="train")
    data_list = []
    for row in ds:
        edge_index = torch.tensor(row["edge_index"], dtype=torch.long)
        x = torch.tensor(row["node_feat"], dtype=torch.float)
        y = torch.tensor(row["y"], dtype=torch.long)
        num_nodes = row["num_nodes"]
        # Handle edge_attr if available
        edge_attr = None
        if row.get("edge_attr") is not None:
            edge_attr = torch.tensor(row["edge_attr"], dtype=torch.float)
        data = Data(x=x, edge_index=edge_index, y=y, edge_attr=edge_attr)
        data.num_nodes = num_nodes
        data_list.append(data)
    return data_list
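
# Sanity note: MUTAG should load as 188 graphs with 7-dimensional one-hot
# node features (atom types) and 2 classes (mutagenic vs. non-mutagenic).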
class GIN(nn.Module):
    """
    Graph Isomorphism Network (GIN) implementation.
    Key insight: SUM aggregation (not mean/max) is injective for multisets,
    making GIN as expressive as the 1-WL test for graph isomorphism.
    """
    def __init__(self, in_channels, hidden_channels, num_classes, num_layers=5, dropout=0.5):
        super().__init__()
        self.num_layers = num_layers
        self.dropout = dropout
        self.convs = nn.ModuleList()
        self.batch_norms = nn.ModuleList()
        # Build MLPs for each GINConv layer:
        #   GIN-0:       train_eps=False (epsilon fixed at 0)
        #   GIN-epsilon: train_eps=True  (learnable epsilon)
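        # Each GINConv implements the node update from the GIN paper:
        #   h_v^(k) = MLP^(k)((1 + eps) * h_v^(k-1) + sum over u in N(v) of h_u^(k-1))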
        for i in range(num_layers):
            in_dim = in_channels if i == 0 else hidden_channels
            mlp = nn.Sequential(
                nn.Linear(in_dim, hidden_channels),
                nn.ReLU(),
                nn.Linear(hidden_channels, hidden_channels)
            )
            self.convs.append(GINConv(mlp, train_eps=True))
            self.batch_norms.append(nn.BatchNorm1d(hidden_channels))
        # Graph-level readout classifier.
        # Sum pooling is critical: of sum, mean, and max, only sum is injective on multisets.
        self.fc1 = nn.Linear(hidden_channels, hidden_channels)
        self.fc2 = nn.Linear(hidden_channels, num_classes)
    def forward(self, x, edge_index, batch):
        # Node-level GIN layers
        for conv, bn in zip(self.convs, self.batch_norms):
            x = conv(x, edge_index)
            x = bn(x)
            x = F.relu(x)
            x = F.dropout(x, p=self.dropout, training=self.training)
        # Graph-level readout: SUM aggregation across all nodes in each graph
        x = global_add_pool(x, batch)
        # Final classifier
        x = F.relu(self.fc1(x))
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.fc2(x)
        return x
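
# Illustrative sketch (not used in the training pipeline): the expressiveness
# argument from the class docstring in three comparisons. Mean and max
# aggregation collapse the distinct neighbor multisets {1, 2} and
# {1, 1, 2, 2}; sum keeps them apart.
def _aggregation_demo():
    a = torch.tensor([1.0, 2.0])            # multiset {1, 2}
    b = torch.tensor([1.0, 1.0, 2.0, 2.0])  # multiset {1, 1, 2, 2}
    print(f"mean: {a.mean().item():.1f} vs {b.mean().item():.1f} (collide)")
    print(f"max:  {a.max().item():.1f} vs {b.max().item():.1f} (collide)")
    print(f"sum:  {a.sum().item():.1f} vs {b.sum().item():.1f} (distinct)")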
def train_epoch(model, loader, optimizer, device):
    model.train()
    total_loss = 0
    for data in loader:
        data = data.to(device)
        optimizer.zero_grad()
        out = model(data.x, data.edge_index, data.batch)
        loss = F.cross_entropy(out, data.y)
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * data.num_graphs
    return total_loss / len(loader.dataset)
@torch.no_grad()
def evaluate(model, loader, device):
    model.eval()
    preds, labels = [], []
    for data in loader:
        data = data.to(device)
        out = model(data.x, data.edge_index, data.batch)
        pred = out.argmax(dim=1)
        preds.extend(pred.cpu().numpy())
        labels.extend(data.y.cpu().numpy())
    return accuracy_score(labels, preds)
def cross_validate_gin(data_list, num_classes, num_folds=10, num_epochs=200):
    """Stratified k-fold cross-validation (10 folds by default, as in the original GIN paper)."""
    labels = [d.y.item() for d in data_list]
    skf = StratifiedKFold(n_splits=num_folds, shuffle=True, random_state=42)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Device: {device}")
    fold_accs = []
    for fold, (train_idx, test_idx) in enumerate(skf.split(data_list, labels)):
        print(f"\n=== Fold {fold + 1}/{num_folds} ===")
        train_data = [data_list[i] for i in train_idx]
        test_data = [data_list[i] for i in test_idx]
        train_loader = DataLoader(train_data, batch_size=32, shuffle=True)
        test_loader = DataLoader(test_data, batch_size=32)
        in_channels = train_data[0].x.shape[1]
        model = GIN(in_channels, hidden_channels=64, num_classes=num_classes).to(device)
        optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
        best_test_acc = 0
        for epoch in range(num_epochs):
            loss = train_epoch(model, train_loader, optimizer, device)
            # Evaluate (and track the best test accuracy) every 50 epochs
            if (epoch + 1) % 50 == 0:
                train_acc = evaluate(model, train_loader, device)
                test_acc = evaluate(model, test_loader, device)
                print(f"  Epoch {epoch+1}: loss={loss:.4f}, train_acc={train_acc:.4f}, test_acc={test_acc:.4f}")
                if test_acc > best_test_acc:
                    best_test_acc = test_acc
        # Final evaluation
        final_test_acc = evaluate(model, test_loader, device)
        fold_accs.append(final_test_acc)
        print(f"  Fold {fold+1} best test accuracy: {best_test_acc:.4f}, final: {final_test_acc:.4f}")
print(f"\n=== Results ===")
print(f"Mean accuracy: {np.mean(fold_accs)*100:.2f}% ± {np.std(fold_accs)*100:.2f}%")
print(f"All fold accuracies: {[f'{a*100:.1f}%' for a in fold_accs]}")
return fold_accs
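
# Note: MUTAG is small (188 graphs), so accuracy varies widely between folds;
# the mean ± std reported above is the figure to compare against published numbers.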
def main():
    print("=" * 60)
    print("Graph Isomorphism Network (GIN) - MUTAG Classification")
    print("=" * 60)
    # 1. Load MUTAG
    print("\n[1] Loading MUTAG dataset from HuggingFace...")
    data_list = load_mutag_from_hf()
    num_classes = len(set(d.y.item() for d in data_list))
    print(f"  Loaded {len(data_list)} graphs, {num_classes} classes")
    print(f"  Node feature dim: {data_list[0].x.shape[1]}")
    print(f"  Average nodes per graph: {np.mean([d.num_nodes for d in data_list]):.1f}")
    # 2. Train with cross-validation
    print("\n[2] Training GIN with 5-fold cross-validation...")
    fold_accs = cross_validate_gin(data_list, num_classes, num_folds=5, num_epochs=150)
    print("\n" + "=" * 60)
    print("KEY TAKEAWAYS - Why Isomorphism Matters for AI")
    print("=" * 60)
print("""
1. GIN uses SUM aggregation (not mean/max) — the ONLY injective
multiset function. This makes it AS EXPRESSIVE as the 1-WL
(Weisfeiler-Lehman) graph isomorphism test.
2. Traditional GCN/GAT use mean/max pooling, which CANNOT distinguish
certain graph structures. GIN can.
3. This expressiveness is proven by theory and practice:
- GIN achieves SOTA on many graph classification benchmarks
- It generalizes better to unseen graph structures
- It learns true structural representations, not just node features
4. Applications:
- Molecular property prediction (drug discovery)
- Social network analysis
- Knowledge graph reasoning
- Program analysis & code similarity
- Anomaly detection in transaction graphs
5. The epsilon (ε) parameter in GIN controls how much the central
node's own features contribute — learnable or fixed (GIN-0).
""")
if __name__ == "__main__":
    main()