| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| import argparse, os, json, math, random |
| from pathlib import Path |
| from typing import Dict, List, Optional, Tuple |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from torch.utils.data import random_split |
| from torch_geometric.datasets import TUDataset |
| from torch_geometric.loader import DataLoader |
| from torch_geometric.nn import GCNConv, global_mean_pool, TopKPooling |
| from torch_geometric.utils import to_dense_adj, add_self_loops, remove_self_loops, coalesce, subgraph |
|
|
# Alias for torch's device type — apparently intended for type annotations;
# unused in this chunk (TODO confirm against the rest of the file).
Device = torch.device
|
|
| |
| |
| |
class DatasetWithGID(torch.utils.data.Dataset):
    """Thin wrapper around a PyG dataset that tags every item with its global
    index as a 1-element long tensor, so Batch collates it into ``data.gid``
    (shape [num_graphs, 1])."""

    def __init__(self, base):
        self.base = base
        # Mirror the common dataset attributes so downstream code that
        # inspects num_classes / num_features keeps working.
        for attr in ("num_classes", "num_features"):
            if hasattr(base, attr):
                setattr(self, attr, getattr(base, attr))

    def __len__(self):
        return len(self.base)

    def __getitem__(self, idx):
        item = self.base[idx]
        item.gid = torch.tensor([idx], dtype=torch.long)
        return item
|
|
def set_seed(seed: int):
    """Seed the Python and torch RNGs and force deterministic cuDNN kernels."""
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Trade speed for reproducibility in convolution algorithm selection.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
|
|
def export_proteins_edgelists(root: Path, out_dir: Path):
    """Dump every PROTEINS graph as a whitespace-separated edge list
    (1-indexed nodes), one file per graph, for the external LRMC seeder."""
    dataset = TUDataset(root=str(root), name='PROTEINS')
    out_dir.mkdir(parents=True, exist_ok=True)
    for i, g in enumerate(dataset):
        # 1-indexed endpoints, one "u v" pair per line.
        lines = [f"{u + 1} {v + 1}\n" for u, v in g.edge_index.t().tolist()]
        (out_dir / f"graph_{i:06d}.txt").write_text("".join(lines))
    print(f"[export] Wrote {len(dataset)} edge lists to {out_dir}")
|
|
def load_lrmc_seeds_dir(seeds_dir: Optional[Path]) -> Optional[List[Optional[List[List[int]]]]]:
    """Load per-graph LRMC seed clusters from a directory of JSON files.

    Files are named ``graph_000000.json`` or ``<i>.json`` and contain
    ``{"clusters": [{"members": [...]}, ...]}`` (``"nodes"`` is accepted as
    an alias for ``"members"``).

    Returns a list indexed by graph id where each entry is a list of
    clusters (each a list of node indices), or ``None`` for graph ids with
    no seed file.  Returns ``None`` if ``seeds_dir`` is ``None`` or no
    parseable seed file is found.
    """
    if seeds_dir is None:
        return None
    by_graph: Dict[int, List[List[int]]] = {}
    for p in sorted(seeds_dir.glob("*.json")):
        stem = p.stem
        try:
            gi = int(stem.split('_')[-1]) if stem.startswith("graph_") else int(stem)
        except ValueError:
            # File name does not encode a graph index -> skip it.
            # (Was a bare `except:`, which also swallowed KeyboardInterrupt.)
            continue
        obj = json.loads(p.read_text())
        clusters = []
        for c in obj.get("clusters", []):
            mem = c.get("members") or c.get("nodes") or []
            clusters.append([int(x) for x in mem])
        by_graph[gi] = clusters
    if not by_graph:
        print(f"[warn] no seed jsons found in {seeds_dir}")
        return None
    # Dense list indexed by graph id; missing ids stay None.
    out: List[Optional[List[List[int]]]] = [None] * (max(by_graph) + 1)
    for i, clusters in by_graph.items():
        out[i] = clusters
    return out
|
|
| |
| |
| |
class PlainGCN(nn.Module):
    """Two-layer GCN baseline: conv -> conv -> global mean pool -> linear head."""

    def __init__(self, in_dim: int, hidden: int, num_classes: int):
        super().__init__()
        self.conv1 = GCNConv(in_dim, hidden, add_self_loops=True, normalize=True)
        self.conv2 = GCNConv(hidden, hidden, add_self_loops=True, normalize=True)
        self.lin = nn.Linear(hidden, num_classes)

    def forward(self, data):
        h = F.relu(self.conv1(data.x, data.edge_index))
        h = F.relu(self.conv2(h, data.edge_index))
        # Graph-level readout: average node embeddings per graph.
        return self.lin(global_mean_pool(h, data.batch))
|
|
class gPoolNet(nn.Module):
    """GCN with a single TopK (gPool) layer between the two convolutions."""

    def __init__(self, in_dim: int, hidden: int, num_classes: int, ratio: float):
        super().__init__()
        self.conv1 = GCNConv(in_dim, hidden)
        self.pool1 = TopKPooling(hidden, ratio=ratio)
        self.conv2 = GCNConv(hidden, hidden)
        self.lin = nn.Linear(hidden, num_classes)

    def forward(self, data):
        h = F.relu(self.conv1(data.x, data.edge_index))
        # TopKPooling keeps the highest-scoring `ratio` fraction of nodes
        # and returns the induced edge index / batch vector.
        h, ei, _, b, _, _ = self.pool1(h, data.edge_index, None, data.batch)
        h = F.relu(self.conv2(h, ei))
        return self.lin(global_mean_pool(h, b))
|
|
class DiffPoolOneShot(nn.Module):
    """Single DiffPool-style layer applied independently to each graph.

    For every graph, a soft assignment matrix S (n_i x K_i) is computed from
    node embeddings, the graph is coarsened to Z = S^T z and A' = S^T A S,
    one more GCN layer runs on the coarse (dense-derived) graph, and the
    coarse nodes are mean-pooled into a graph embedding.

    NOTE(review): the assignment head has a fixed width of
    ``max(1, round(ratio * 50))``; for graphs where ``round(ratio * n_i)``
    exceeds that width, the prefix slice silently caps K_i at the head's
    output size — confirm 50 is an adequate node-count bound for PROTEINS.
    """

    def __init__(self, in_dim: int, hidden: int, num_classes: int, ratio: float):
        super().__init__()
        self.gnn_embed = GCNConv(in_dim, hidden)
        # Fixed-width assignment head; the per-graph K_i is a prefix slice.
        self.assign = nn.Linear(hidden, max(1, int(round(ratio * 50))))
        self.post_conv = GCNConv(hidden, hidden)
        self.lin = nn.Linear(hidden, num_classes)
        self.ratio = ratio

    def forward(self, data):
        # The original forward contained a dead, duplicated first pass whose
        # per-graph results were discarded before this (complete) pass
        # re-initialized everything; only the complete pass is kept.
        x, edge_index, batch = data.x, data.edge_index, data.batch
        gids = data.gid.view(-1)
        out = []
        for gi_local, _gid in enumerate(gids.tolist()):
            subset_idx = torch.where(batch == gi_local)[0]
            if subset_idx.numel() == 0:
                continue
            # Induced, relabelled edge index and features for this graph.
            eidx_i, _ = subgraph(subset_idx, edge_index, relabel_nodes=True, num_nodes=x.size(0))
            xi = x[subset_idx]

            zi = F.relu(self.gnn_embed(xi, eidx_i))
            n_i = zi.size(0)
            K_i = max(1, int(round(self.ratio * n_i)))
            # Soft cluster assignment and dense coarsening (DiffPool).
            Si = F.softmax(self.assign(zi)[:, :K_i], dim=-1)
            Ai = to_dense_adj(eidx_i, max_num_nodes=n_i)[0]
            Zi = Si.t() @ zi
            Ai_pooled = Si.t() @ Ai @ Si

            rows, cols = (Ai_pooled > 0).nonzero(as_tuple=True)
            if rows.numel() == 0:
                # Coarse graph has no edges: fall back to self-loops so
                # post_conv still propagates each cluster's own features.
                # Use the actual number of coarse nodes (Si.size(1)), which
                # may be smaller than K_i when the assignment head caps it;
                # arange(K_i) could otherwise index out of range.
                loops = torch.arange(Si.size(1), device=Zi.device)
                edge_index_coarse = torch.stack([loops, loops])
                edge_weight = None
            else:
                edge_index_coarse = torch.stack([rows, cols], dim=0)
                edge_weight = Ai_pooled[rows, cols]
            Zi2 = F.relu(self.post_conv(Zi, edge_index_coarse, edge_weight))
            out.append(Zi2.mean(dim=0, keepdim=True))
        # Empty batch -> empty logits with the right width.
        Xg = torch.cat(out, dim=0) if out else torch.zeros(0, self.lin.in_features, device=x.device)
        return self.lin(Xg)
|
|
|
|
class LRMCPoolNet(nn.Module):
    """Pooling GCN driven by precomputed (hard) LRMC cluster seeds.

    For each graph the seed clusters become coarse nodes: node features are
    mean-aggregated per cluster, cluster-to-cluster edges are induced from
    the original edges, one GCN layer runs on the coarse graph, and the
    coarse nodes are mean-pooled.  Graphs without (valid) seeds fall back to
    plain mean pooling of the node embeddings.
    """

    def __init__(self, in_dim: int, hidden: int, num_classes: int, pool_ratio: float,
                 seeds_by_graph: Optional[List[Optional[List[List[int]]]]] = None):
        super().__init__()
        self.conv1 = GCNConv(in_dim, hidden)
        self.conv2 = GCNConv(hidden, hidden)
        self.post_conv = GCNConv(hidden, hidden)
        self.lin = nn.Linear(hidden, num_classes)
        self.pool_ratio = pool_ratio
        self.seeds_by_graph = seeds_by_graph

    @staticmethod
    def _normalize_clusters(clusters: List[List[int]], n_i: int) -> Optional[List[List[int]]]:
        """Return deduplicated, 0-based, in-range clusters, or None if invalid.

        Seed files may be 1-based (the exporter writes u+1 / v+1): if every
        member across ALL clusters lies in [1, n_i], the whole graph is
        shifted down by one exactly once.  The previous implementation made
        this decision per cluster AND repeated the shift in a second pass,
        silently double-shifting clusters that did not contain node 0.
        """
        cleaned = [sorted({int(x) for x in mem}) for mem in clusters if mem]
        if not cleaned:
            return []
        lo = min(mem[0] for mem in cleaned)
        hi = max(mem[-1] for mem in cleaned)
        if lo >= 1 and hi <= n_i:
            # Looks 1-based: convert once, globally.
            cleaned = [[x - 1 for x in mem] for mem in cleaned]
            lo, hi = lo - 1, hi - 1
        if lo < 0 or hi >= n_i:
            return None
        return cleaned

    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch
        z = F.relu(self.conv1(x, edge_index))
        z = F.relu(self.conv2(z, edge_index))

        gids = data.gid.view(-1)
        pooled = []
        for gi_local, gi_global in enumerate(gids.tolist()):
            mask = (batch == gi_local)
            zi = z[mask]

            # Local (per-graph) node numbering and the induced edge index.
            idx_nodes = torch.where(mask)[0]
            local = -torch.ones(z.size(0), dtype=torch.long, device=z.device)
            local[idx_nodes] = torch.arange(idx_nodes.size(0), device=z.device)
            edge_mask = (batch[edge_index[0]] == gi_local) & (batch[edge_index[1]] == gi_local)
            eidx_i = local[edge_index[:, edge_mask]]
            n_i = zi.size(0)

            clusters = None
            if self.seeds_by_graph is not None and 0 <= gi_global < len(self.seeds_by_graph):
                clusters = self.seeds_by_graph[gi_global]

            # One-shot debug line on the first graph ever processed.
            if not hasattr(self, "_dbg"):
                self._dbg = True
                print(f"[LRMC] used_seeds={clusters is not None} K={len(clusters) if clusters else 0} n={n_i}")

            if clusters:
                clusters = self._normalize_clusters(clusters, n_i)
                if clusters is None:
                    # Out-of-range seed indices: warn once per graph id.
                    if not hasattr(self, "_bad_warned"):
                        self._bad_warned = set()
                    if gi_global not in self._bad_warned:
                        print(
                            f"[LRMC] Warning: seed indices out of range for graph gid={gi_global} (n={n_i}); falling back to mean pooling for this graph.")
                        self._bad_warned.add(gi_global)

            if not clusters:
                # No usable seeds: mean-pool the node embeddings directly.
                pooled.append(zi.mean(dim=0, keepdim=True))
                continue

            # Keep at most targetK clusters, preferring the largest ones.
            targetK = max(1, int(round(self.pool_ratio * n_i)))
            if len(clusters) > targetK:
                clusters = sorted(clusters, key=len, reverse=True)[:targetK]

            # Hard assignment; on overlap, later clusters win (as before).
            cid = -torch.ones(n_i, dtype=torch.long, device=z.device)
            for k, mem in enumerate(clusters):
                cid[torch.tensor(mem, dtype=torch.long, device=z.device)] = k

            # Unassigned nodes become singleton clusters until targetK is
            # reached; any remainder is dumped into the last cluster.
            for u in torch.where(cid < 0)[0].tolist():
                clusters.append([int(u)])
                cid[u] = len(clusters) - 1
                if len(clusters) >= targetK:
                    break
            if (cid < 0).any():
                cid[cid < 0] = len(clusters) - 1

            # Mean feature per cluster.
            Zi = torch.zeros(len(clusters), z.size(1), device=z.device)
            Zi.index_add_(0, cid, zi)
            counts = torch.bincount(cid, minlength=len(clusters)).clamp(min=1).view(-1, 1).to(Zi.dtype)
            Zi = Zi / counts

            # Induced coarse edges, deduplicated via coalesce.
            e_src = cid[eidx_i[0]]
            e_dst = cid[eidx_i[1]]
            e_coarse, _ = coalesce(torch.stack([e_src, e_dst], dim=0), None, len(clusters), len(clusters))

            Zi2 = F.relu(self.post_conv(Zi, e_coarse))
            pooled.append(Zi2.mean(dim=0, keepdim=True))
        Xg = torch.cat(pooled, dim=0)
        return self.lin(Xg)
|
|
| |
| |
| |
def train_epoch(model, loader, device, opt):
    """Run one optimisation epoch; return the mean per-graph training loss."""
    model.train()
    running = 0.0
    for batch_data in loader:
        batch_data = batch_data.to(device)
        opt.zero_grad()
        logits = model(batch_data)
        loss = F.cross_entropy(logits, batch_data.y)
        loss.backward()
        opt.step()
        # Weight by graph count so the epoch average is per-graph.
        running += loss.detach().item() * batch_data.num_graphs
    return running / len(loader.dataset)
|
|
@torch.no_grad()
def evaluate(model, loader, device):
    """Return classification accuracy of ``model`` over ``loader``."""
    model.eval()
    hits, seen = 0, 0
    for batch_data in loader:
        batch_data = batch_data.to(device)
        preds = model(batch_data).argmax(dim=-1)
        hits += int((preds == batch_data.y).sum())
        seen += batch_data.num_graphs
    return hits / seen
|
|
| |
| |
| |
def main(args):
    """Train and evaluate all four models on PROTEINS with an 80/10/10 split."""
    set_seed(args.seed)
    device = torch.device('cuda' if torch.cuda.is_available() and not args.cpu else 'cpu')

    base = TUDataset(root=str(args.data_root), name='PROTEINS')
    # DatasetWithGID.__getitem__ attaches data.gid = [idx] on every access, so
    # Batch collates a per-graph global id.  (The previous `ds[i].gid = ...`
    # loop mutated throwaway copies returned by __getitem__ and did nothing.)
    ds = DatasetWithGID(base)

    num_classes = base.num_classes
    # PROTEINS ships node-label features, so num_features > 0 in practice;
    # the floor of 1 only guards degenerate feature-less datasets.  (The old
    # `if in_dim == 0:` degree-feature block was unreachable — in_dim is
    # already forced >= 1 here — and would not have persisted anyway, since
    # iterating a TUDataset yields copies.)
    in_dim = base.num_features if base.num_features and base.num_features > 0 else 1

    # Deterministic 80/10/10 split under args.seed.
    N = len(ds)
    n_train = int(0.8 * N); n_val = int(0.1 * N); n_test = N - n_train - n_val
    train_set, val_set, test_set = random_split(ds, [n_train, n_val, n_test],
                                                generator=torch.Generator().manual_seed(args.seed))
    train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True)
    val_loader = DataLoader(val_set, batch_size=args.batch_size, shuffle=False)
    test_loader = DataLoader(test_set, batch_size=args.batch_size, shuffle=False)

    seeds_by_graph = load_lrmc_seeds_dir(Path(args.seeds_dir)) if args.seeds_dir else None

    # The four models under comparison, all sharing the same pool ratio.
    hidden = args.hidden
    plain = PlainGCN(in_dim, hidden, num_classes).to(device)
    lrmc = LRMCPoolNet(in_dim, hidden, num_classes, pool_ratio=args.pool_ratio,
                       seeds_by_graph=seeds_by_graph).to(device)
    gpool = gPoolNet(in_dim, hidden, num_classes, ratio=args.pool_ratio).to(device)
    diffp = DiffPoolOneShot(in_dim, hidden, num_classes, ratio=args.pool_ratio).to(device)

    for name, model in [('PlainGCN', plain), ('L-RMC', lrmc), ('gPool', gpool), ('DiffPool', diffp)]:
        opt = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=5e-4)
        best_val = 0.0
        best_state = None
        for epoch in range(1, args.epochs + 1):
            loss = train_epoch(model, train_loader, device, opt)
            acc_val = evaluate(model, val_loader, device)
            # Keep the latest state that ties or beats the best val accuracy.
            if acc_val >= best_val:
                best_val = acc_val
                best_state = {k: v.detach().cpu() for k, v in model.state_dict().items()}
            if epoch % 20 == 0 or epoch == args.epochs:
                print(f"{name} epoch {epoch}: loss={loss:.4f}, val_acc={acc_val:.3f}")
        if best_state is not None:
            model.load_state_dict(best_state, strict=False)
        acc_test = evaluate(model, test_loader, device)
        print(f"{name:8s} test_acc={acc_test:.3f}")
|
|
if __name__ == "__main__":
    # CLI entry point: either export per-graph edge lists for the external
    # (Java) LRMC seeder, or run the full training/evaluation pipeline.
    p = argparse.ArgumentParser()
    p.add_argument("--data_root", type=str, default="./data", help="Root dir for TUDataset(PROTEINS)")
    p.add_argument("--export_edgelists", action="store_true", help="Export per-graph edgelists for Java LRMC seeder")
    p.add_argument("--out_dir", type=str, default="./proteins_edgelists", help="Where to write edge lists")
    p.add_argument("--seeds_dir", type=str, default="", help="Directory with per-graph LRMC seed JSON files")
    # Pooling ratio shared by all three pooling models.
    p.add_argument("--pool_ratio", type=float, default=0.5)
    p.add_argument("--hidden", type=int, default=64)
    p.add_argument("--batch_size", type=int, default=64)
    p.add_argument("--epochs", type=int, default=200)
    p.add_argument("--seed", type=int, default=42)
    p.add_argument("--cpu", action="store_true")
    args = p.parse_args()

    # --export_edgelists short-circuits training entirely.
    if args.export_edgelists:
        export_proteins_edgelists(Path(args.data_root), Path(args.out_dir))
    else:
        main(args)
|
|