# clique / src / old / 5_compare_gcn_pools_proteins_nomask2_fixed.py
# (Hugging Face upload header, preserved as comments: uploaded by qingy2024
#  via huggingface_hub, commit bf620c6 verified)
# 5_compare_gcn_pools_reddit_nomask.py (PROTEINS graph classification)
# Compares:
# - Plain GCN
# - GCN + LRMC hard clustering (contract via S^T A S, S^T X)
# - GCN + DiffPool (1 layer, dense per-graph)
# - GCN + gPool (TopKPooling)
#
# Also provides utilities to export PROTEINS into per-graph edgelists (1-based)
# and to consume LRMC seeds dumped per graph from Java.
#
# Usage examples:
# # 1) Export per-graph edge lists (for Java LRMC seeder)
# python 5_compare_gcn_pools_reddit_nomask.py --export_edgelists --out_dir ./proteins_edgelists
#
# # 2) Train/eval with precomputed LRMC seeds (JSON files in seeds_dir)
# python 5_compare_gcn_pools_reddit_nomask.py --seeds_dir ./proteins_seeds --pool_ratio 0.5 --epochs 200
#
# # 3) Train/eval without LRMC (still runs Plain/DiffPool/gPool)
# python 5_compare_gcn_pools_reddit_nomask.py --pool_ratio 0.5 --epochs 200
#
import argparse, os, json, math, random
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import random_split
from torch_geometric.datasets import TUDataset
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GCNConv, global_mean_pool, TopKPooling
from torch_geometric.utils import to_dense_adj, add_self_loops, remove_self_loops, coalesce, subgraph
Device = torch.device
# -----------------------------
# Utilities
# -----------------------------
class DatasetWithGID(torch.utils.data.Dataset):
"""Wraps a PyG dataset and attaches a per-item global id as a tensor
so that Batch will collate it into data.gid (shape [num_graphs, 1])."""
def __init__(self, base):
self.base = base
# Expose common attributes for convenience
for attr in ("num_classes", "num_features"):
if hasattr(base, attr):
setattr(self, attr, getattr(base, attr))
def __len__(self): return len(self.base)
def __getitem__(self, idx):
data = self.base[idx]
data.gid = torch.tensor([idx], dtype=torch.long)
return data
def set_seed(seed: int):
random.seed(seed); torch.manual_seed(seed); torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
def export_proteins_edgelists(root: Path, out_dir: Path):
ds = TUDataset(root=str(root), name='PROTEINS')
out_dir.mkdir(parents=True, exist_ok=True)
for i, data in enumerate(ds):
n = int(data.num_nodes)
ei = data.edge_index
# 1-based ids per graph
with (out_dir / f"graph_{i:06d}.txt").open('w') as f:
for u, v in ei.t().tolist():
f.write(f"{u+1} {v+1}\n")
print(f"[export] Wrote {len(ds)} edge lists to {out_dir}")
def load_lrmc_seeds_dir(seeds_dir: Optional[Path]) -> Optional[List[List[List[int]]]]:
"""Return list indexed by graph idx -> list of clusters -> list of node indices.
Expect files named graph_000000.json OR i.json with structure {"clusters":[{"members":[...]}, ...]}
"""
if seeds_dir is None: return None
by_graph: Dict[int, List[List[int]]] = {}
for p in sorted(seeds_dir.glob("*.json")):
stem = p.stem
try:
gi = int(stem.split('_')[-1]) if stem.startswith("graph_") else int(stem)
except:
continue
obj = json.loads(p.read_text())
clusters = []
for c in obj.get("clusters", []):
mem = c.get("members") or c.get("nodes") or []
clusters.append([int(x) for x in mem])
by_graph[gi] = clusters
if not by_graph:
print(f"[warn] no seed jsons found in {seeds_dir}")
return None
# Convert to list ordered by graph idx with possible Nones for missing
max_i = max(by_graph.keys())
out: List[Optional[List[List[int]]]] = [None]*(max_i+1)
for i, clusters in by_graph.items():
out[i] = clusters
return out
# -----------------------------
# Models
# -----------------------------
class PlainGCN(nn.Module):
def __init__(self, in_dim: int, hidden: int, num_classes: int):
super().__init__()
self.conv1 = GCNConv(in_dim, hidden, add_self_loops=True, normalize=True)
self.conv2 = GCNConv(hidden, hidden, add_self_loops=True, normalize=True)
self.lin = nn.Linear(hidden, num_classes)
def forward(self, data):
x, edge_index, batch = data.x, data.edge_index, data.batch
x = F.relu(self.conv1(x, edge_index))
x = F.relu(self.conv2(x, edge_index))
x = global_mean_pool(x, batch)
return self.lin(x)
class gPoolNet(nn.Module):
def __init__(self, in_dim: int, hidden: int, num_classes: int, ratio: float):
super().__init__()
self.conv1 = GCNConv(in_dim, hidden)
self.pool1 = TopKPooling(hidden, ratio=ratio)
self.conv2 = GCNConv(hidden, hidden)
self.lin = nn.Linear(hidden, num_classes)
def forward(self, data):
x, edge_index, batch = data.x, data.edge_index, data.batch
x = F.relu(self.conv1(x, edge_index))
x, edge_index, _, batch, _, _ = self.pool1(x, edge_index, None, batch)
x = F.relu(self.conv2(x, edge_index))
x = global_mean_pool(x, batch)
return self.lin(x)
class DiffPoolOneShot(nn.Module):
# One DiffPool layer per graph (dense inside each graph; PROTEINS graphs are small).
def __init__(self, in_dim: int, hidden: int, num_classes: int, ratio: float):
super().__init__()
self.gnn_embed = GCNConv(in_dim, hidden)
self.assign = nn.Linear(hidden,
max(1, int(round(ratio * 50)))) # fallback K upper bound; reset per-graph at runtime
self.post_conv = GCNConv(hidden, hidden)
self.lin = nn.Linear(hidden, num_classes)
self.ratio = ratio
def forward(self, data):
# We will process graphs in the batch independently to avoid padding logic.
x, edge_index, batch = data.x, data.edge_index, data.batch
out = []
for gi in batch.unique().tolist():
mask = (batch == gi)
xi = x[mask]
# select intra-graph edges and reindex to local [0..n_i-1]
edges_mask = (batch[edge_index[0]] == gi) & (batch[edge_index[1]] == gi)
eidx_g = edge_index[:, edges_mask]
idx = torch.where(mask)[0]
n_i = idx.numel()
local = torch.full((batch.size(0),), -1, dtype=torch.long, device=batch.device)
local[idx] = torch.arange(n_i, device=batch.device)
eidx_i = local[eidx_g]
# embed nodes
zi = F.relu(self.gnn_embed(xi, eidx_i))
K_i = max(1, int(round(self.ratio * n_i)))
Si = F.softmax(self.assign(zi)[:, :K_i], dim=-1) # [n_i, K_i]
Ai = to_dense_adj(eidx_i, max_num_nodes=n_i)[0] # [n_i, n_i]
Zi = Si.t() @ zi # [K_i, hidden]
Ai_pooled = Si.t() @ Ai @ Si # [K_i, K_i]
# Process each graph in the batch independently with relabeled node indices.
x, edge_index, batch = data.x, data.edge_index, data.batch
# global graph ids (dataset indices)
gids = data.gid.view(-1)
out = []
for gi_local, gi_global in enumerate(gids.tolist()):
subset_idx = torch.where(batch == gi_local)[0]
if subset_idx.numel() == 0:
continue
# Relabel edge_index to [0..n_i-1] for this subgraph
eidx_i, _ = subgraph(subset_idx, edge_index, relabel_nodes=True, num_nodes=x.size(0))
xi = x[subset_idx]
# embed nodes
zi = F.relu(self.gnn_embed(xi, eidx_i))
n_i = zi.size(0)
K_i = max(1, int(round(self.ratio * n_i)))
Si = F.softmax(self.assign(zi)[:, :K_i], dim=-1) # [n_i, K_i]
Ai = to_dense_adj(eidx_i, max_num_nodes=n_i)[0] # [n_i, n_i]
Zi = Si.t() @ zi # [K_i, hidden]
Ai_pooled = Si.t() @ Ai @ Si # [K_i, K_i]
# Convert Ai_pooled to sparse for GCNConv
rows, cols = (Ai_pooled > 0).nonzero(as_tuple=True)
if rows.numel() == 0:
# Fallback to identity if empty
edge_index_coarse = torch.stack([torch.arange(K_i, device=Zi.device), torch.arange(K_i, device=Zi.device)])
edge_weight = None
else:
edge_index_coarse = torch.stack([rows, cols], dim=0)
edge_weight = Ai_pooled[rows, cols]
Zi2 = F.relu(self.post_conv(Zi, edge_index_coarse, edge_weight))
out.append(Zi2.mean(dim=0, keepdim=True))
Xg = torch.cat(out, dim=0) if out else torch.zeros(0, self.lin.in_features, device=x.device)
return self.lin(Xg)
class LRMCPoolNet(nn.Module):
"""Hard cluster assignment from LRMC seeds per graph; falls back to no pooling if seeds missing."""
def __init__(self, in_dim: int, hidden: int, num_classes: int, pool_ratio: float,
seeds_by_graph: Optional[List[Optional[List[List[int]]]]] = None):
super().__init__()
self.conv1 = GCNConv(in_dim, hidden)
self.conv2 = GCNConv(hidden, hidden)
self.post_conv = GCNConv(hidden, hidden)
self.lin = nn.Linear(hidden, num_classes)
self.pool_ratio = pool_ratio
self.seeds_by_graph = seeds_by_graph
def forward(self, data):
x, edge_index, batch = data.x, data.edge_index, data.batch
z = F.relu(self.conv1(x, edge_index))
z = F.relu(self.conv2(z, edge_index))
# global ids for graphs in this batch (shape [num_graphs])
gids = data.gid.view(-1) # one per graph in the batch
pooled = []
for gi_local, gi_global in enumerate(gids.tolist()):
mask = (batch == gi_local)
zi = z[mask]
# build local 0..n_i-1 index map and reindex edges
idx_nodes = torch.where(mask)[0]
local = -torch.ones(z.size(0), dtype=torch.long, device=z.device)
local[idx_nodes] = torch.arange(idx_nodes.size(0), device=z.device)
edge_mask = (batch[edge_index[0]] == gi_local) & (batch[edge_index[1]] == gi_local)
eidx_i = local[edge_index[:, edge_mask]]
n_i = zi.size(0)
clusters = None
if self.seeds_by_graph is not None and 0 <= gi_global < len(self.seeds_by_graph):
clusters = self.seeds_by_graph[gi_global]
if not hasattr(self, "_dbg"):
self._dbg = True
print(f"[LRMC] used_seeds={clusters is not None} K={len(clusters) if clusters else 0} n={n_i}")
# Validate clusters: must be a list of int lists within node range
valid = True
if clusters:
cleaned = []
for mem in clusters:
if not mem:
continue
mt = torch.tensor(mem, dtype=torch.long, device=z.device)
# Try 1-based → 0-based if it fits
if mt.min().item() >= 1 and mt.max().item() <= n_i and (mt == 0).sum().item() == 0:
mt = mt - 1
# Now check bounds
if (mt < 0).any() or (mt >= n_i).any():
valid = False
break
cleaned.append(mt.tolist())
clusters = cleaned if valid else None
if not clusters:
# Fallback: identity pooling (no change)
pooled.append(zi.mean(dim=0, keepdim=True))
continue
# target K and simple coarsen
targetK = max(1, int(round(self.pool_ratio * n_i)))
clusters = [sorted(set(c)) for c in clusters if len(c) >= 1]
if len(clusters) > targetK:
clusters = sorted(clusters, key=len, reverse=True)[:targetK]
# Build cluster id vector (auto-detect 1-based vs 0-based) and clamp
cid = -torch.ones(n_i, dtype=torch.long, device=z.device)
bad = False
for k, mem in enumerate(clusters):
mem_t = torch.tensor(mem, dtype=torch.long, device=z.device)
if mem_t.numel() > 0 and mem_t.min().item() >= 1 and mem_t.max().item() <= n_i and (
mem_t == 0).sum().item() == 0:
mem_t = mem_t - 1
# if out-of-range even after adjustment, bail out for this graph
if mem_t.numel() and (mem_t.min().item() < 0 or mem_t.max().item() >= n_i):
bad = True
break
cid[mem_t] = k
if bad:
# fallback if seeds look inconsistent with this graph
if not hasattr(self, "_bad_warned"):
self._bad_warned = set()
if gi_global not in self._bad_warned:
print(
f"[LRMC] Warning: seed indices out of range for graph gid={gi_global} (n={n_i}); falling back to mean pooling for this graph.")
self._bad_warned.add(gi_global)
pooled.append(zi.mean(dim=0, keepdim=True))
continue
# give stragglers their own clusters up to targetK
for u in torch.where(cid < 0)[0].tolist():
clusters.append([int(u)])
cid[u] = len(clusters) - 1
if len(clusters) >= targetK:
break
if (cid < 0).any(): # any still unassigned → dump to last cluster
cid[cid < 0] = len(clusters) - 1
# mean over clusters → one vector per graph
Zi = torch.zeros(len(clusters), z.size(1), device=z.device)
Zi.index_add_(0, cid, zi)
counts = torch.bincount(cid, minlength=len(clusters)).clamp(min=1).view(-1, 1).to(Zi.dtype)
Zi = Zi / counts
# Build coarse edge_index via cluster assignment: e_coarse = coalesce([cid[u], cid[v]])
e_src = cid[eidx_i[0]]
e_dst = cid[eidx_i[1]]
e_coarse, _ = coalesce(torch.stack([e_src, e_dst], dim=0), None, len(clusters), len(clusters))
# Post-pool convolution on (Zi, e_coarse)
Zi2 = F.relu(self.post_conv(Zi, e_coarse))
pooled.append(Zi2.mean(dim=0, keepdim=True))
Xg = torch.cat(pooled, dim=0)
return self.lin(Xg)
# -----------------------------
# Training & eval
# -----------------------------
def train_epoch(model, loader, device, opt):
model.train()
total = 0.0
for data in loader:
data = data.to(device)
opt.zero_grad()
out = model(data)
loss = F.cross_entropy(out, data.y)
loss.backward()
opt.step()
total += loss.detach().item() * data.num_graphs
return total / len(loader.dataset)
@torch.no_grad()
def evaluate(model, loader, device):
model.eval()
correct = 0
total = 0
for data in loader:
data = data.to(device)
pred = model(data).argmax(dim=-1)
correct += int((pred == data.y).sum())
total += data.num_graphs
return correct / total
# -----------------------------
# Main
# -----------------------------
def main(args):
set_seed(args.seed)
device = torch.device('cuda' if torch.cuda.is_available() and not args.cpu else 'cpu')
# Dataset
base = TUDataset(root=str(args.data_root), name='PROTEINS')
ds = DatasetWithGID(base)
# attach global ids so the batch knows which dataset graph each sample is
for i in range(len(ds)):
ds[i].gid = torch.tensor([i], dtype=torch.long)
num_classes = base.num_classes
in_dim = base.num_features if base.num_features and base.num_features > 0 else 1
# Fallback: if no features, use degree as a single feature
if in_dim == 0:
from torch_geometric.utils import degree
for data in base:
deg = degree(data.edge_index[0], num_nodes=data.num_nodes).view(-1,1)
data.x = deg
in_dim = 1
# Splits
N = len(ds)
n_train = int(0.8 * N); n_val = int(0.1 * N); n_test = N - n_train - n_val
train_set, val_set, test_set = random_split(ds, [n_train, n_val, n_test],
generator=torch.Generator().manual_seed(args.seed))
train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True)
val_loader = DataLoader(val_set, batch_size=args.batch_size, shuffle=False)
test_loader = DataLoader(test_set, batch_size=args.batch_size, shuffle=False)
seeds_by_graph = load_lrmc_seeds_dir(Path(args.seeds_dir)) if args.seeds_dir else None
# Models
hidden = args.hidden
plain = PlainGCN(in_dim, hidden, num_classes).to(device)
lrmc = LRMCPoolNet(in_dim, hidden, num_classes, pool_ratio=args.pool_ratio,
seeds_by_graph=seeds_by_graph).to(device)
gpool = gPoolNet(in_dim, hidden, num_classes, ratio=args.pool_ratio).to(device)
diffp = DiffPoolOneShot(in_dim, hidden, num_classes, ratio=args.pool_ratio).to(device)
# Train each model separately
for name, model in [('PlainGCN', plain), ('L-RMC', lrmc), ('gPool', gpool), ('DiffPool', diffp)]:
opt = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=5e-4)
best_val = 0.0; best_state = None
for epoch in range(1, args.epochs+1):
loss = train_epoch(model, train_loader, device, opt)
acc_val = evaluate(model, val_loader, device)
if acc_val >= best_val:
best_val = acc_val; best_state = {k: v.detach().cpu() for k, v in model.state_dict().items()}
if epoch % 20 == 0 or epoch == args.epochs:
print(f"{name} epoch {epoch}: loss={loss:.4f}, val_acc={acc_val:.3f}")
if best_state is not None: model.load_state_dict(best_state, strict=False)
acc_test = evaluate(model, test_loader, device)
print(f"{name:8s} test_acc={acc_test:.3f}")
if __name__ == "__main__":
p = argparse.ArgumentParser()
p.add_argument("--data_root", type=str, default="./data", help="Root dir for TUDataset(PROTEINS)")
p.add_argument("--export_edgelists", action="store_true", help="Export per-graph edgelists for Java LRMC seeder")
p.add_argument("--out_dir", type=str, default="./proteins_edgelists", help="Where to write edge lists")
p.add_argument("--seeds_dir", type=str, default="", help="Directory with per-graph LRMC seed JSON files")
p.add_argument("--pool_ratio", type=float, default=0.5)
p.add_argument("--hidden", type=int, default=64)
p.add_argument("--batch_size", type=int, default=64)
p.add_argument("--epochs", type=int, default=200)
p.add_argument("--seed", type=int, default=42)
p.add_argument("--cpu", action="store_true")
args = p.parse_args()
if args.export_edgelists:
export_proteins_edgelists(Path(args.data_root), Path(args.out_dir))
else:
main(args)