arkheb commited on
Commit
e998e94
·
verified ·
1 Parent(s): 0102cf6

Upload gin_train.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. gin_train.py +215 -0
gin_train.py ADDED
@@ -0,0 +1,215 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Graph Isomorphism Network (GIN) training on MUTAG dataset.
3
+ Demonstrates isomorphism-aware learning for graph classification.
4
+ """
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ from torch_geometric.nn import GINConv, global_add_pool
10
+ from torch_geometric.loader import DataLoader
11
+ from datasets import load_dataset
12
+ import numpy as np
13
+ from sklearn.model_selection import StratifiedKFold
14
+ from sklearn.metrics import accuracy_score
15
+ import random
16
+
17
# Seed every RNG the script touches (torch, numpy, python) for reproducibility.
torch.manual_seed(42)
np.random.seed(42)
random.seed(42)
21
+
22
+
23
def load_mutag_from_hf():
    """Load MUTAG from the HuggingFace hub and build a list of PyG Data objects.

    Returns:
        list of torch_geometric.data.Data, one per molecular graph, carrying
        node features ``x``, connectivity ``edge_index``, label ``y`` and,
        when the dataset row provides them, edge features ``edge_attr``.
    """
    dataset = load_dataset("graphs-datasets/MUTAG", split="train")

    from torch_geometric.data import Data

    graphs = []
    for record in dataset:
        # Edge features are optional in the hub schema; keep None when absent.
        raw_edge_attr = record.get("edge_attr")
        graph = Data(
            x=torch.tensor(record["node_feat"], dtype=torch.float),
            edge_index=torch.tensor(record["edge_index"], dtype=torch.long),
            y=torch.tensor(record["y"], dtype=torch.long),
            edge_attr=(
                None
                if raw_edge_attr is None
                else torch.tensor(raw_edge_attr, dtype=torch.float)
            ),
        )
        graph.num_nodes = record["num_nodes"]
        graphs.append(graph)

    return graphs
46
+
47
+
48
class GIN(nn.Module):
    """Graph Isomorphism Network (GIN) for graph-level classification.

    Key insight: SUM aggregation (not mean/max) is injective for multisets,
    making GIN as expressive as the 1-WL test for graph isomorphism.
    """

    def __init__(self, in_channels, hidden_channels, num_classes, num_layers=5, dropout=0.5):
        super().__init__()
        self.num_layers = num_layers
        self.dropout = dropout

        self.convs = nn.ModuleList()
        self.batch_norms = nn.ModuleList()

        # One 2-layer MLP per GINConv layer.
        # train_eps=True -> GIN-epsilon (learnable epsilon);
        # train_eps=False would be GIN-0 (epsilon fixed at zero).
        for layer in range(num_layers):
            width_in = in_channels if layer == 0 else hidden_channels
            self.convs.append(
                GINConv(
                    nn.Sequential(
                        nn.Linear(width_in, hidden_channels),
                        nn.ReLU(),
                        nn.Linear(hidden_channels, hidden_channels),
                    ),
                    train_eps=True,
                )
            )
            self.batch_norms.append(nn.BatchNorm1d(hidden_channels))

        # Graph-level classification head; forward() uses sum pooling because
        # it is the injective multiset readout.
        self.fc1 = nn.Linear(hidden_channels, hidden_channels)
        self.fc2 = nn.Linear(hidden_channels, num_classes)

    def forward(self, x, edge_index, batch):
        # Node-level GIN message passing: conv -> batchnorm -> relu -> dropout.
        for conv, bn in zip(self.convs, self.batch_norms):
            x = conv(x, edge_index)
            x = F.dropout(F.relu(bn(x)), p=self.dropout, training=self.training)

        # Graph-level readout: SUM node embeddings within each graph.
        x = global_add_pool(x, batch)

        # Two-layer classifier on the pooled graph embedding.
        hidden = F.dropout(
            F.relu(self.fc1(x)), p=self.dropout, training=self.training
        )
        return self.fc2(hidden)
96
+
97
+
98
+ def train_epoch(model, loader, optimizer, device):
99
+ model.train()
100
+ total_loss = 0
101
+ for data in loader:
102
+ data = data.to(device)
103
+ optimizer.zero_grad()
104
+ out = model(data.x, data.edge_index, data.batch)
105
+ loss = F.cross_entropy(out, data.y)
106
+ loss.backward()
107
+ optimizer.step()
108
+ total_loss += loss.item() * data.num_graphs
109
+ return total_loss / len(loader.dataset)
110
+
111
+
112
@torch.no_grad()
def evaluate(model, loader, device):
    """Compute classification accuracy of `model` over all graphs in `loader`.

    Args:
        model: graph classifier called as model(x, edge_index, batch).
        loader: iterable of batched graph objects.
        device: torch device the batches are moved to.

    Returns:
        float: fraction of graphs whose argmax prediction matches the label.
    """
    model.eval()
    preds, labels = [], []
    for data in loader:
        data = data.to(device)
        out = model(data.x, data.edge_index, data.batch)
        pred = out.argmax(dim=1)
        preds.extend(pred.cpu().numpy())
        labels.extend(data.y.cpu().numpy())
    # Computed directly with numpy instead of sklearn.metrics.accuracy_score:
    # same value, one fewer heavyweight dependency in the hot path, and a
    # plain Python float rather than np.float64.
    return float(np.mean(np.asarray(labels) == np.asarray(preds)))
123
+
124
+
125
def cross_validate_gin(data_list, num_classes, num_folds=10, num_epochs=200):
    """Stratified k-fold cross-validation for GIN on a list of graphs.

    The original GIN paper evaluates with 10 folds (the default here), but
    ``num_folds`` is a parameter and callers may choose otherwise.

    Args:
        data_list: list of PyG Data graphs with a scalar label in ``.y``.
        num_classes: number of target classes.
        num_folds: number of stratified folds (default 10, per the paper).
        num_epochs: training epochs per fold.

    Returns:
        list[float]: final-epoch test accuracy of each fold.
    """
    labels = [d.y.item() for d in data_list]
    skf = StratifiedKFold(n_splits=num_folds, shuffle=True, random_state=42)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Device: {device}")

    fold_accs = []

    for fold, (train_idx, test_idx) in enumerate(skf.split(data_list, labels)):
        print(f"\n=== Fold {fold + 1}/{num_folds} ===")

        train_data = [data_list[i] for i in train_idx]
        test_data = [data_list[i] for i in test_idx]

        train_loader = DataLoader(train_data, batch_size=32, shuffle=True)
        test_loader = DataLoader(test_data, batch_size=32)

        # Fresh model and optimizer per fold so folds stay independent.
        in_channels = train_data[0].x.shape[1]
        model = GIN(in_channels, hidden_channels=64, num_classes=num_classes).to(device)
        optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

        # NOTE(review): best_test_acc is sampled only every 50 epochs and
        # peeks at the test split; it is printed for reference only — the
        # returned metric is the final-epoch accuracy.
        best_test_acc = 0
        for epoch in range(num_epochs):
            loss = train_epoch(model, train_loader, optimizer, device)
            if (epoch + 1) % 50 == 0:
                train_acc = evaluate(model, train_loader, device)
                test_acc = evaluate(model, test_loader, device)
                print(f" Epoch {epoch+1}: loss={loss:.4f}, train_acc={train_acc:.4f}, test_acc={test_acc:.4f}")
                if test_acc > best_test_acc:
                    best_test_acc = test_acc

        # Final evaluation on the held-out fold.
        final_test_acc = evaluate(model, test_loader, device)
        fold_accs.append(final_test_acc)
        print(f" Fold {fold+1} best test accuracy: {best_test_acc:.4f}, final: {final_test_acc:.4f}")

    print(f"\n=== Results ===")
    print(f"Mean accuracy: {np.mean(fold_accs)*100:.2f}% ± {np.std(fold_accs)*100:.2f}%")
    print(f"All fold accuracies: {[f'{a*100:.1f}%' for a in fold_accs]}")
    return fold_accs
167
+
168
+
169
def main():
    """Entry point: load MUTAG, cross-validate GIN, and print the takeaways."""
    banner = "=" * 60
    print(banner)
    print("Graph Isomorphism Network (GIN) - MUTAG Classification")
    print(banner)

    # 1. Load MUTAG
    print("\n[1] Loading MUTAG dataset from HuggingFace...")
    data_list = load_mutag_from_hf()
    num_classes = len(set(d.y.item() for d in data_list))
    print(f" Loaded {len(data_list)} graphs, {num_classes} classes")
    print(f" Node feature dim: {data_list[0].x.shape[1]}")
    print(f" Average nodes per graph: {np.mean([d.num_nodes for d in data_list]):.1f}")

    # 2. Train with cross-validation
    print("\n[2] Training GIN with 5-fold cross-validation...")
    fold_accs = cross_validate_gin(data_list, num_classes, num_folds=5, num_epochs=150)

    print("\n" + banner)
    print("KEY TAKEAWAYS - Why Isomorphism Matters for AI")
    print(banner)
    print("""
1. GIN uses SUM aggregation (not mean/max) — the ONLY injective
multiset function. This makes it AS EXPRESSIVE as the 1-WL
(Weisfeiler-Lehman) graph isomorphism test.

2. Traditional GCN/GAT use mean/max pooling, which CANNOT distinguish
certain graph structures. GIN can.

3. This expressiveness is proven by theory and practice:
- GIN achieves SOTA on many graph classification benchmarks
- It generalizes better to unseen graph structures
- It learns true structural representations, not just node features

4. Applications:
- Molecular property prediction (drug discovery)
- Social network analysis
- Knowledge graph reasoning
- Program analysis & code similarity
- Anomaly detection in transaction graphs

5. The epsilon (ε) parameter in GIN controls how much the central
node's own features contribute — learnable or fixed (GIN-0).
""")
212
+
213
+
214
if __name__ == "__main__":
    # Run only when executed as a script, not when imported as a module.
    main()