Upload gin_train.py with huggingface_hub
Browse files- gin_train.py +215 -0
gin_train.py
ADDED
|
@@ -0,0 +1,215 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Graph Isomorphism Network (GIN) training on MUTAG dataset.
|
| 3 |
+
Demonstrates isomorphism-aware learning for graph classification.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
from torch_geometric.nn import GINConv, global_add_pool
|
| 10 |
+
from torch_geometric.loader import DataLoader
|
| 11 |
+
from datasets import load_dataset
|
| 12 |
+
import numpy as np
|
| 13 |
+
from sklearn.model_selection import StratifiedKFold
|
| 14 |
+
from sklearn.metrics import accuracy_score
|
| 15 |
+
import random
|
| 16 |
+
|
| 17 |
+
# Seed every RNG this script uses (torch, numpy, stdlib random) so that
# model init, fold shuffling, and dropout are reproducible across runs.
torch.manual_seed(42)
np.random.seed(42)
random.seed(42)
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def load_mutag_from_hf():
    """Fetch the MUTAG dataset from the HuggingFace Hub and return it as a
    list of PyTorch Geometric ``Data`` objects (one per graph)."""
    ds = load_dataset("graphs-datasets/MUTAG", split="train")

    from torch_geometric.data import Data

    graphs = []
    for record in ds:
        # Optional bond features; not every row is guaranteed to carry them.
        edge_attr = None
        if record.get("edge_attr") is not None:
            edge_attr = torch.tensor(record["edge_attr"], dtype=torch.float)

        graph = Data(
            x=torch.tensor(record["node_feat"], dtype=torch.float),
            edge_index=torch.tensor(record["edge_index"], dtype=torch.long),
            y=torch.tensor(record["y"], dtype=torch.long),
            edge_attr=edge_attr,
        )
        # The dataset stores the node count explicitly; set it so PyG does
        # not have to re-infer it from edge_index.
        graph.num_nodes = record["num_nodes"]
        graphs.append(graph)

    return graphs
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
class GIN(nn.Module):
    """Graph Isomorphism Network (GIN) for graph classification.

    Key insight (Xu et al., "How Powerful are Graph Neural Networks?",
    2019): sum aggregation over neighbor multisets, followed by an MLP,
    can represent an injective multiset function, which makes GIN as
    expressive as the 1-WL (Weisfeiler-Lehman) isomorphism test —
    mean/max aggregators cannot achieve this.

    Args:
        in_channels: Dimensionality of the input node features.
        hidden_channels: Width of every hidden layer and the classifier.
        num_classes: Number of output graph classes.
        num_layers: Number of stacked GINConv layers. Default 5.
        dropout: Dropout probability after each layer. Default 0.5.
        train_eps: If True (default, "GIN-eps"), the epsilon weighting of
            the central node is learned; if False ("GIN-0"), epsilon is
            fixed at 0. Added as a backward-compatible switch — the
            previous behavior (True) is the default.
    """

    def __init__(self, in_channels, hidden_channels, num_classes,
                 num_layers=5, dropout=0.5, train_eps=True):
        super().__init__()
        self.num_layers = num_layers
        self.dropout = dropout

        self.convs = nn.ModuleList()
        self.batch_norms = nn.ModuleList()

        # One 2-layer MLP per GINConv: the first layer maps the raw input
        # dimension, subsequent layers stay at hidden width.
        for i in range(num_layers):
            in_dim = in_channels if i == 0 else hidden_channels
            mlp = nn.Sequential(
                nn.Linear(in_dim, hidden_channels),
                nn.ReLU(),
                nn.Linear(hidden_channels, hidden_channels),
            )
            self.convs.append(GINConv(mlp, train_eps=train_eps))
            self.batch_norms.append(nn.BatchNorm1d(hidden_channels))

        # Graph-level classifier head. Sum pooling (global_add_pool) is
        # used for the readout because sum preserves multiset information
        # that mean/max pooling discards.
        self.fc1 = nn.Linear(hidden_channels, hidden_channels)
        self.fc2 = nn.Linear(hidden_channels, num_classes)

    def forward(self, x, edge_index, batch):
        """Return per-graph class logits of shape (num_graphs, num_classes)."""
        # Node-level message passing.
        for conv, bn in zip(self.convs, self.batch_norms):
            x = conv(x, edge_index)
            x = bn(x)
            x = F.relu(x)
            x = F.dropout(x, p=self.dropout, training=self.training)

        # Graph-level readout: sum node embeddings within each graph.
        x = global_add_pool(x, batch)

        # Final MLP classifier on the pooled representation.
        x = F.relu(self.fc1(x))
        x = F.dropout(x, p=self.dropout, training=self.training)
        return self.fc2(x)
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
def train_epoch(model, loader, optimizer, device):
    """Run one optimization pass over `loader`.

    Returns the dataset-wide mean training loss (per-graph weighted).
    """
    model.train()
    running = 0.0
    for batch in loader:
        batch = batch.to(device)
        optimizer.zero_grad()
        logits = model(batch.x, batch.edge_index, batch.batch)
        batch_loss = F.cross_entropy(logits, batch.y)
        batch_loss.backward()
        optimizer.step()
        # Weight by graph count so the mean is per-graph, not per-batch.
        running += batch_loss.item() * batch.num_graphs
    return running / len(loader.dataset)
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
@torch.no_grad()
def evaluate(model, loader, device):
    """Return classification accuracy of `model` over all graphs in `loader`.

    Runs in eval mode with gradients disabled. Accuracy is computed
    directly with numpy instead of sklearn's `accuracy_score` — results
    are identical for this use, the block no longer needs a third-party
    metric for a one-liner, and an empty loader now returns 0.0 instead
    of raising.
    """
    model.eval()
    preds, labels = [], []
    for data in loader:
        data = data.to(device)
        out = model(data.x, data.edge_index, data.batch)
        pred = out.argmax(dim=1)
        preds.extend(pred.cpu().numpy())
        labels.extend(data.y.cpu().numpy())
    if not labels:
        # No graphs to score.
        return 0.0
    return float(np.mean(np.asarray(labels) == np.asarray(preds)))
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
def cross_validate_gin(data_list, num_classes, num_folds=10, num_epochs=200):
    """Stratified k-fold cross-validation for GIN graph classification
    (the evaluation protocol used in the original GIN paper).

    Trains a fresh model per fold and returns the list of final test-set
    accuracies, one entry per fold.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Device: {device}")

    labels = [graph.y.item() for graph in data_list]
    splitter = StratifiedKFold(n_splits=num_folds, shuffle=True, random_state=42)

    accuracies = []
    for fold, (train_idx, test_idx) in enumerate(splitter.split(data_list, labels)):
        print(f"\n=== Fold {fold + 1}/{num_folds} ===")

        train_loader = DataLoader([data_list[i] for i in train_idx], batch_size=32, shuffle=True)
        test_loader = DataLoader([data_list[i] for i in test_idx], batch_size=32)

        feature_dim = data_list[train_idx[0]].x.shape[1]
        model = GIN(feature_dim, hidden_channels=64, num_classes=num_classes).to(device)
        optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

        best_test_acc = 0
        for epoch in range(num_epochs):
            loss = train_epoch(model, train_loader, optimizer, device)
            # Periodic progress report every 50 epochs.
            if (epoch + 1) % 50 == 0:
                train_acc = evaluate(model, train_loader, device)
                test_acc = evaluate(model, test_loader, device)
                print(f" Epoch {epoch+1}: loss={loss:.4f}, train_acc={train_acc:.4f}, test_acc={test_acc:.4f}")
                if test_acc > best_test_acc:
                    best_test_acc = test_acc

        # The reported per-fold score is the accuracy after the LAST epoch,
        # not the best-seen one (which is printed for reference only).
        final_test_acc = evaluate(model, test_loader, device)
        accuracies.append(final_test_acc)
        print(f" Fold {fold+1} best test accuracy: {best_test_acc:.4f}, final: {final_test_acc:.4f}")

    print(f"\n=== Results ===")
    print(f"Mean accuracy: {np.mean(accuracies)*100:.2f}% ± {np.std(accuracies)*100:.2f}%")
    print(f"All fold accuracies: {[f'{a*100:.1f}%' for a in accuracies]}")
    return accuracies
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
def main():
    """Entry point: load MUTAG, run GIN cross-validation, print a summary."""
    banner = "=" * 60
    print(banner)
    print("Graph Isomorphism Network (GIN) - MUTAG Classification")
    print(banner)

    # 1. Load MUTAG
    print("\n[1] Loading MUTAG dataset from HuggingFace...")
    graphs = load_mutag_from_hf()
    num_classes = len({g.y.item() for g in graphs})
    print(f" Loaded {len(graphs)} graphs, {num_classes} classes")
    print(f" Node feature dim: {graphs[0].x.shape[1]}")
    print(f" Average nodes per graph: {np.mean([g.num_nodes for g in graphs]):.1f}")

    # 2. Train with cross-validation
    print("\n[2] Training GIN with 5-fold cross-validation...")
    cross_validate_gin(graphs, num_classes, num_folds=5, num_epochs=150)

    print("\n" + banner)
    print("KEY TAKEAWAYS - Why Isomorphism Matters for AI")
    print(banner)
    print("""
1. GIN uses SUM aggregation (not mean/max) — the ONLY injective
multiset function. This makes it AS EXPRESSIVE as the 1-WL
(Weisfeiler-Lehman) graph isomorphism test.

2. Traditional GCN/GAT use mean/max pooling, which CANNOT distinguish
certain graph structures. GIN can.

3. This expressiveness is proven by theory and practice:
- GIN achieves SOTA on many graph classification benchmarks
- It generalizes better to unseen graph structures
- It learns true structural representations, not just node features

4. Applications:
- Molecular property prediction (drug discovery)
- Social network analysis
- Knowledge graph reasoning
- Program analysis & code similarity
- Anomaly detection in transaction graphs

5. The epsilon (ε) parameter in GIN controls how much the central
node's own features contribute — learnable or fixed (GIN-0).
""")
|
| 212 |
+
|
| 213 |
+
|
| 214 |
+
# Standard script entry guard: run training only when executed directly.
if __name__ == "__main__":
    main()
|