| |
| |
| |
|
|
| import argparse, json, hashlib |
| from pathlib import Path |
| from typing import List, Tuple, Optional |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from torch import Tensor |
|
|
| from torch_scatter import scatter_mean |
| from torch_sparse import coalesce, spspmm |
| from torch_geometric.datasets import Planetoid |
| from torch_geometric.nn import GCNConv |
|
|
| from rich import print |
|
|
|
|
| |
| |
| |
|
|
def add_scaled_self_loops(edge_index: Tensor,
                          edge_weight: Optional[Tensor],
                          num_nodes: int,
                          scale: float = 1.0) -> Tuple[Tensor, Tensor]:
    """Add a self-loop of weight ``scale`` to every node of a weighted graph.

    Args:
        edge_index: ``(2, E)`` edge list.
        edge_weight: per-edge weights, or ``None`` to treat every edge as 1.
        num_nodes: number of nodes; one ``(i, i)`` entry is added per node.
        scale: weight given to each added self-loop.

    Returns:
        ``(edge_index, edge_weight)``. With ``scale == 0.0`` the input edges
        are returned untouched (only missing weights are materialised as
        ones). Otherwise the result is coalesced with ``op='add'``, so a
        self-loop that already exists has ``scale`` added to its weight.
    """
    device = edge_index.device
    if edge_weight is None:
        weights = torch.ones(edge_index.size(1), device=device)
    else:
        weights = edge_weight

    # Nothing to add: skip the coalesce so the caller's ordering is preserved.
    if scale == 0.0:
        return edge_index, weights

    node_ids = torch.arange(num_nodes, device=device)
    loop_index = torch.stack((node_ids, node_ids), dim=0)
    loop_weight = torch.full((num_nodes,), float(scale), device=device)

    merged_index = torch.cat([edge_index, loop_index], dim=1)
    merged_weight = torch.cat([weights, loop_weight], dim=0)
    merged_index, merged_weight = coalesce(
        merged_index, merged_weight, num_nodes, num_nodes, op='add'
    )
    return merged_index, merged_weight
|
|
|
|
def adjacency_power(edge_index: Tensor, num_nodes: int, k: int = 2) -> Tensor:
    """Return the off-diagonal sparsity pattern of ``A**k`` as an edge index.

    ``A`` is the binary adjacency matrix described by ``edge_index``. Only the
    pattern is returned (no walk counts): an entry ``(i, j)`` with ``i != j``
    is present when at least one walk of length ``k`` leads from ``i`` to ``j``.

    Args:
        edge_index: ``(2, E)`` edge list of the graph.
        num_nodes: number of nodes (``A`` is ``num_nodes x num_nodes``).
        k: exponent, ``>= 1``. ``k == 1`` yields the coalesced input pattern
            without self-loops; ``k == 2`` is the two-hop pattern.

    Returns:
        Coalesced ``(2, E_k)`` edge index with diagonal entries removed.

    Raises:
        ValueError: if ``k < 1``.
    """
    if k < 1:
        raise ValueError(f"adjacency_power requires k >= 1, got {k}.")

    device = edge_index.device
    base_value = torch.ones(edge_index.size(1), device=device)

    reach_index, reach_value = edge_index, base_value
    for _step in range(k - 1):
        # Diagonal entries are kept between steps: walks may pass through
        # their own start node, so removing them early would lose entries.
        reach_index, _counts = spspmm(reach_index, reach_value,
                                      edge_index, base_value,
                                      num_nodes, num_nodes, num_nodes)
        # Re-binarise so values stay bounded; only the pattern is needed.
        reach_value = torch.ones(reach_index.size(1), device=device)

    off_diagonal = reach_index[0] != reach_index[1]
    reach_index = reach_index[:, off_diagonal]
    reach_index, _merged = coalesce(
        reach_index,
        torch.ones(reach_index.size(1), device=device),
        num_nodes, num_nodes, op='add',
    )
    return reach_index
|
|
|
|
| def _md5(path: Path) -> str: |
| h = hashlib.md5() |
| with path.open('rb') as f: |
| for chunk in iter(lambda: f.read(8192), b''): |
| h.update(chunk) |
| return h.hexdigest() |
|
|
|
|
| |
| |
| |
|
|
def _extract_members(cluster_obj: dict) -> List[int]:
    """Return the node ids of a cluster, de-duplicated in first-seen order.

    The ``"members"`` key is preferred; ``"seed_nodes"`` is the fallback when
    ``"members"`` is missing, empty, or not a list. Entries are converted
    with ``int``.

    Returns:
        The id list, or ``[]`` when at least one of the two keys holds a list
        but neither list has any entries.

    Raises:
        KeyError: if neither key holds a list.
    """
    candidates = [cluster_obj.get(key, None) for key in ("members", "seed_nodes")]

    for value in candidates:
        if not isinstance(value, list) or len(value) == 0:
            continue
        seen = set()
        ordered = []
        for raw in value:
            node = int(raw)
            if node not in seen:
                seen.add(node)
                ordered.append(node)
        return ordered

    # A list was present but empty: an empty cluster rather than a schema error.
    if any(isinstance(value, list) for value in candidates):
        return []
    raise KeyError("Cluster object has neither 'members' nor 'seed_nodes'.")
|
|
|
|
def _pick_top1_cluster(obj: dict) -> List[int]:
    """Return the sorted, unique node ids of the best cluster in ``obj``.

    Clusters are read from ``obj["clusters"]`` and ranked by
    ``(score, member count)``; a missing ``"score"`` counts as ``0.0`` and a
    cluster without a usable member list counts as having zero members.

    Returns:
        Sorted unique ids of the winning cluster, or ``[]`` when
        ``"clusters"`` is missing, empty, or not a list.
    """
    clusters = obj.get("clusters", [])
    if not (isinstance(clusters, list) and len(clusters) > 0):
        return []

    def members_or_empty(cluster):
        # A cluster lacking both member keys is treated as empty, not fatal.
        try:
            return _extract_members(cluster)
        except KeyError:
            return []

    def rank(cluster):
        return (float(cluster.get("score", 0.0)), len(members_or_empty(cluster)))

    winner = max(clusters, key=rank)
    return sorted({int(node) for node in members_or_empty(winner)})
|
|
|
|
def refine_k_core(C_star: List[int], edge_index: Tensor, k: int = 2, rounds: int = 50) -> List[int]:
    """Refine cluster by taking a k-core of its induced subgraph (label-free purity boost).

    Edges are treated as undirected: ``(u, v)`` and ``(v, u)`` count as a
    single edge and self-loops are ignored, so a node's degree is its number
    of distinct neighbours inside the cluster, whether ``edge_index`` stores
    one or both directions. Nodes with fewer than ``k`` such neighbours are
    peeled off repeatedly.

    Args:
        C_star: node ids of the cluster.
        edge_index: ``(2, E)`` edge list of the whole graph.
        k: minimum within-cluster degree. ``k <= 0`` disables refinement.
        rounds: maximum number of peeling passes.

    Returns:
        Sorted ids of the surviving nodes (possibly empty). When ``k <= 0``
        or ``C_star`` is empty, ``C_star`` is returned unchanged.
    """
    if k <= 0 or len(C_star) == 0:
        return C_star

    device = edge_index.device
    members = sorted(set(int(u) for u in C_star))

    # Size the mask to cover both edge endpoints and cluster ids: a cluster
    # node may be isolated and carry an id above every edge endpoint.
    max_edge_id = int(edge_index.max().item()) if edge_index.numel() > 0 else -1
    size = max(max_edge_id, members[-1]) + 1
    in_set = torch.zeros(size, dtype=torch.bool, device=device)
    in_set[torch.tensor(members, dtype=torch.long, device=device)] = True

    # Canonical undirected edge list: drop self-loops, orient as (lo, hi) and
    # de-duplicate so reciprocal pairs do not double the degree.
    src, dst = edge_index[0], edge_index[1]
    proper = src != dst
    lo = torch.minimum(src, dst)[proper]
    hi = torch.maximum(src, dst)[proper]
    if lo.numel() > 0:
        pairs = torch.unique(torch.stack([lo, hi], dim=0), dim=1)
        lo, hi = pairs[0], pairs[1]

    for _ in range(rounds):
        internal = in_set[lo] & in_set[hi]
        degree = (torch.bincount(lo[internal], minlength=size)
                  + torch.bincount(hi[internal], minlength=size))
        # No early exit when there are no internal edges: every remaining
        # node then has degree 0 < k and must be removed as well.
        kill = in_set & (degree < k)
        if not bool(kill.any()):
            break
        in_set = in_set & ~kill

    return torch.nonzero(in_set, as_tuple=False).view(-1).tolist()
|
|
|
|
def load_top1_assignment(seeds_json: str, n_nodes: int,
                         debug: bool = False,
                         refine_k: int = 0,
                         edge_index_for_refine: Optional[Tensor] = None) -> Tuple[Tensor, Tensor, dict]:
    """
    Hard assignment for top-1 LRMC cluster with optional k-core refinement.
    cluster 0 = top cluster; others are singletons.

    Args:
        seeds_json: path to the seeds JSON file.
        n_nodes: number of nodes in the graph.
        debug: print a one-line summary of what was loaded.
        refine_k: if ``> 0``, refine the top cluster with ``refine_k_core``.
        edge_index_for_refine: graph edges, required when ``refine_k > 0``.

    Returns:
        ``(node2cluster, cluster_scores, info)`` where ``node2cluster`` is a
        length-``n_nodes`` long tensor (0 for the top cluster, ``1..K-1`` for
        the remaining nodes in ascending id order), ``cluster_scores`` is a
        ``(K, 1)`` float tensor that is 1.0 for cluster 0 and 0.0 elsewhere,
        and ``info`` is a dict of summary fields (file md5, sizes, first ids).

    Raises:
        IndexError: if a cluster id falls outside ``[0, n_nodes)``.
        ValueError: if ``refine_k > 0`` but no edge index was supplied.
        RuntimeError: if the top cluster is empty, either as loaded or after
            refinement.
    """
    p = Path(seeds_json)
    obj = json.loads(p.read_text(encoding='utf-8'))
    C_star = _pick_top1_cluster(obj)
    had_members = len(C_star) > 0

    if had_members and max(C_star) == n_nodes:
        # An id equal to n_nodes cannot be 0-indexed, so treat the file as
        # 1-indexed and shift every id down by one.
        C_star = [u - 1 for u in C_star]

    # Validate explicitly: a negative id would otherwise wrap around in the
    # tensor indexing below and silently mark the wrong node.
    if had_members and (min(C_star) < 0 or max(C_star) >= n_nodes):
        raise IndexError(
            f"Cluster ids in {seeds_json} must lie in [0, {n_nodes}); "
            f"got range [{min(C_star)}, {max(C_star)}]."
        )

    if refine_k > 0:
        if edge_index_for_refine is None:
            raise ValueError("--refine_k requires access to edge_index for refinement.")
        C_star = refine_k_core(C_star, edge_index_for_refine, k=refine_k)

    if len(C_star) == 0:
        if had_members:
            raise RuntimeError(
                f"Top-1 cluster in {seeds_json} became empty after "
                f"k-core refinement with k={refine_k}."
            )
        raise RuntimeError(
            f"No members found for top-1 cluster in {seeds_json}. "
            f"Expected 'members' or 'seed_nodes' to be non-empty."
        )

    C = torch.tensor(C_star, dtype=torch.long)

    # Nodes outside the top cluster, in ascending id order, each become a
    # singleton cluster numbered from 1.
    in_top = torch.zeros(n_nodes, dtype=torch.bool)
    in_top[C] = True
    outside = torch.nonzero(~in_top, as_tuple=False).view(-1)

    node2cluster = torch.full((n_nodes,), -1, dtype=torch.long)
    node2cluster[C] = 0
    if outside.numel() > 0:
        node2cluster[outside] = torch.arange(1, 1 + outside.numel(), dtype=torch.long)

    K = 1 + outside.numel()
    cluster_scores = torch.zeros(K, 1, dtype=torch.float32)
    cluster_scores[0, 0] = 1.0

    info = {
        "json_md5": _md5(p),
        "top_cluster_size": int(C.numel()),
        "K": int(K),
        "n_outside": int(outside.numel()),
        "first_members": [int(x) for x in C[:10].tolist()],
    }
    if debug:
        print(f"[LRMC] Loaded {seeds_json} (md5={info['json_md5']}) | "
              f"top_size={info['top_cluster_size']} K={info['K']} outside={info['n_outside']} "
              f"first10={info['first_members']}")
    return node2cluster, cluster_scores, info
|
|
|
|
| |
| |
| |
|
|
def _sparsify_topk(edge_index: Tensor, edge_weight: Tensor, K: int, topk: int) -> Tuple[Tensor, Tensor]:
    """Keep per-row top-k neighbors by weight; symmetrize and coalesce.

    Args:
        edge_index: ``(2, E)`` edge list over ``K`` clusters.
        edge_weight: weight of each edge.
        K: number of rows/columns of the cluster adjacency.
        topk: neighbours to keep per source row; ``<= 0`` disables
            sparsification and returns the inputs unchanged.

    Returns:
        ``(edge_index, edge_weight)`` holding the kept edges plus their
        reverses; duplicates are merged with ``op='max'``.
    """
    if topk <= 0:
        return edge_index, edge_weight

    source_rows = edge_index[0]
    selected = torch.zeros(edge_weight.numel(), dtype=torch.bool, device=edge_weight.device)

    for cluster_id in range(K):
        positions = torch.nonzero(source_rows == cluster_id, as_tuple=False).view(-1)
        if positions.numel() == 0:
            continue
        n_keep = min(topk, positions.numel())
        best = torch.topk(edge_weight[positions], n_keep).indices
        selected[positions[best]] = True

    kept_index = edge_index[:, selected]
    kept_weight = edge_weight[selected]

    # Add each kept edge's reverse so the result is symmetric.
    reversed_index = torch.stack([kept_index[1], kept_index[0]], dim=0)
    sym_index = torch.cat([kept_index, reversed_index], dim=1)
    sym_weight = torch.cat([kept_weight, kept_weight], dim=0)
    sym_index, sym_weight = coalesce(sym_index, sym_weight, K, K, op='max')
    return sym_index, sym_weight
|
|
|
|
def build_cluster_graph_mixed(edge_index_node: Tensor,
                              num_nodes: int,
                              node2cluster: Tensor,
                              use_a2: bool,
                              a2_gamma: float,
                              drop_self_loops: bool,
                              topk_per_row: int) -> Tuple[Tensor, Tensor, int]:
    """
    Build A_c = S^T (A + γ A²) S, optionally drop diag, then per-row top-k sparsify.

    Every node-level edge has weight 1. When ``use_a2`` is set and
    ``a2_gamma > 0``, each edge returned by ``adjacency_power(..., k=2)`` is
    appended with weight ``a2_gamma``. Endpoints are then mapped through
    ``node2cluster`` and parallel cluster edges are summed.

    Args:
        edge_index_node: ``(2, E)`` node-level edge list.
        num_nodes: number of nodes.
        node2cluster: cluster id for each node.
        use_a2: include the two-hop term.
        a2_gamma: weight of each two-hop edge.
        drop_self_loops: remove ``(c, c)`` entries from the cluster graph.
        topk_per_row: per-row neighbour budget; ``<= 0`` keeps everything.

    Returns:
        ``(edge_index_c, edge_weight_c, K)`` with ``K = node2cluster.max() + 1``.
    """
    device = edge_index_node.device

    edge_parts = [edge_index_node]
    weight_parts = [torch.ones(edge_index_node[0].numel(), device=device)]
    if use_a2 and a2_gamma > 0.0:
        two_hop = adjacency_power(edge_index_node, num_nodes, k=2)
        edge_parts.append(two_hop)
        weight_parts.append(torch.full((two_hop.size(1),), float(a2_gamma), device=device))
    node_edges = torch.cat(edge_parts, dim=1)
    node_weights = torch.cat(weight_parts, dim=0)

    # Project both endpoints onto cluster ids and sum parallel edges.
    n_clusters = int(node2cluster.max().item()) + 1
    sources, targets = node_edges
    cluster_edges = torch.stack([node2cluster[sources], node2cluster[targets]], dim=0)
    cluster_edges, cluster_weights = coalesce(
        cluster_edges, node_weights, n_clusters, n_clusters, op='add'
    )

    if drop_self_loops:
        off_diagonal = cluster_edges[0] != cluster_edges[1]
        cluster_edges = cluster_edges[:, off_diagonal]
        cluster_weights = cluster_weights[off_diagonal]

    if topk_per_row > 0:
        cluster_edges, cluster_weights = _sparsify_topk(
            cluster_edges, cluster_weights, n_clusters, topk_per_row
        )

    return cluster_edges, cluster_weights, n_clusters
|
|
|
|
| |
| |
| |
|
|
class GCN2(nn.Module):
    """Two-layer GCN baseline: conv -> ReLU -> dropout -> conv."""

    def __init__(self, in_dim, hid, out_dim, dropout=0.5):
        super().__init__()
        self.conv1 = GCNConv(in_dim, hid)
        self.conv2 = GCNConv(hid, out_dim)
        self.dropout = dropout

    def forward(self, x, edge_index):
        """Return the second layer's raw output for features ``x``."""
        hidden = self.conv1(x, edge_index)
        hidden = F.relu(hidden)
        # Dropout is only active in training mode.
        hidden = F.dropout(hidden, p=self.dropout, training=self.training)
        return self.conv2(hidden, edge_index)
|
|
|
|
class OneClusterPoolGated(nn.Module):
    """
    Node-GCN -> pool (means) -> Cluster-GCN over sparsified A_c -> residual gate -> Node-GCN -> logits

    Args:
        in_dim: input feature dimension.
        hid: hidden dimension used by the node and cluster layers.
        out_dim: number of output classes.
        node2cluster: cluster id of every node.
        edge_index_node: ``(2, E)`` node-level edge list.
        num_nodes: number of nodes.
        self_loop_scale: weight of the self-loops added to the node graph
            (``0.0`` adds none).
        use_a2_for_clusters: include the two-hop term in the cluster graph.
        a2_gamma: weight of the two-hop term.
        drop_cluster_self_loops: drop ``(c, c)`` entries of the cluster graph.
        cluster_topk: per-row neighbour budget for the cluster graph.
        debug_header: if non-empty, print a summary line with this prefix.
        dropout: dropout probability applied before the final node layer.
    """
    def __init__(self,
                 in_dim: int,
                 hid: int,
                 out_dim: int,
                 node2cluster: Tensor,
                 edge_index_node: Tensor,
                 num_nodes: int,
                 self_loop_scale: float = 0.0,
                 use_a2_for_clusters: bool = False,
                 a2_gamma: float = 0.2,
                 drop_cluster_self_loops: bool = True,
                 cluster_topk: int = 24,
                 debug_header: str = "",
                 dropout: float = 0.5):
        super().__init__()
        # Registered as a buffer so it follows the module across devices like
        # the edge tensors below; non-persistent to keep state_dict keys
        # unchanged.
        self.register_buffer("n2c", node2cluster.long(), persistent=False)
        self.dropout = dropout

        # Node graph, optionally with scaled self-loops.
        ei_node, ew_node = add_scaled_self_loops(edge_index_node, None, num_nodes, scale=self_loop_scale)
        self.register_buffer("edge_index_node", ei_node)
        self.register_buffer("edge_weight_node", ew_node)

        # Cluster graph, built from the edges as passed in (before the
        # self-loops above were added).
        eC, wC, K = build_cluster_graph_mixed(
            edge_index_node, num_nodes, self.n2c,
            use_a2=use_a2_for_clusters, a2_gamma=a2_gamma,
            drop_self_loops=drop_cluster_self_loops, topk_per_row=cluster_topk
        )
        self.register_buffer("edge_index_c", eC)
        self.register_buffer("edge_weight_c", wC)
        self.K = K

        if debug_header:
            print(f"[POOL] {debug_header} | cluster_edges={eC.size(1)} (K={K})")

        self.gcn_node1 = GCNConv(in_dim, hid, add_self_loops=False, normalize=True)
        self.gcn_cluster = GCNConv(hid, hid, add_self_loops=True, normalize=True)
        self.down = nn.Linear(hid, hid)
        self.gate = nn.Sequential(nn.Linear(2*hid, hid//2), nn.ReLU(), nn.Linear(hid//2, 1))
        self.lambda_logit = nn.Parameter(torch.tensor(0.0))
        self.gcn_node2 = GCNConv(hid, out_dim)

    def forward(self, x: Tensor, edge_index_node: Tensor) -> Tensor:
        """Return the final node layer's output for features ``x``.

        ``edge_index_node`` is accepted so the call signature matches
        ``GCN2.forward`` but is not used: the graph stored at construction
        time is used instead.
        """
        # Node-level encoding.
        h1 = F.relu(self.gcn_node1(x, self.edge_index_node, self.edge_weight_node))
        # Mean-pool node features into their clusters.
        z = scatter_mean(h1, self.n2c, dim=0, dim_size=self.K)
        # Message passing on the cluster graph.
        z2 = F.relu(self.gcn_cluster(z, self.edge_index_c, self.edge_weight_c))
        # Broadcast cluster features back to the member nodes.
        hb = z2[self.n2c]
        inj = self.down(hb)
        # Gate = learned blend between always-on (1.0) and a per-node gate.
        gate_dyn = torch.sigmoid(self.gate(torch.cat([h1, inj], dim=1)))
        lam = torch.sigmoid(self.lambda_logit)
        alpha = lam * 1.0 + (1.0 - lam) * gate_dyn
        h2 = h1 + alpha * inj
        h2 = F.dropout(h2, p=self.dropout, training=self.training)
        out = self.gcn_node2(h2, self.edge_index_node, self.edge_weight_node)
        return out
|
|
|
|
| |
| |
| |
|
|
@torch.no_grad()
def accuracy(logits: Tensor, y: Tensor, mask: Tensor) -> float:
    """Return the fraction of ``mask``-selected rows whose argmax equals ``y``."""
    selected_logits = logits[mask]
    selected_labels = y[mask]
    hits = selected_logits.argmax(dim=1).eq(selected_labels)
    return hits.float().mean().item()
|
|
|
|
def run_train_eval(model: nn.Module, data, epochs=200, lr=0.01, wd=5e-4):
    """Train ``model`` full-batch with Adam and report val/test accuracy.

    The model is called as ``model(data.x, data.edge_index)``. After each
    epoch the validation accuracy is measured, and the weights with the best
    validation accuracy so far are snapshotted and restored before the final
    evaluation.

    Args:
        model: module to train (modified in place).
        data: object exposing ``x``, ``edge_index``, ``y``, ``train_mask``,
            ``val_mask`` and ``test_mask``.
        epochs: number of training epochs.
        lr: Adam learning rate.
        wd: Adam weight decay.

    Returns:
        ``{"val": ..., "test": ...}`` accuracies of the restored model.
    """
    opt = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=wd)
    best_val, best_state = 0.0, None
    for ep in range(1, epochs + 1):
        model.train()
        opt.zero_grad(set_to_none=True)
        logits = model(data.x, data.edge_index)
        loss = F.cross_entropy(logits[data.train_mask], data.y[data.train_mask])
        loss.backward()
        opt.step()

        model.eval()
        # Evaluation needs no gradients; skip building the autograd graph.
        with torch.no_grad():
            logits = model(data.x, data.edge_index)
        val_acc = accuracy(logits, data.y, data.val_mask)
        if val_acc > best_val:
            best_val = val_acc
            best_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
        if ep % 20 == 0:
            tr = accuracy(logits, data.y, data.train_mask)
            te = accuracy(logits, data.y, data.test_mask)
            print(f"[{ep:04d}] loss={loss.item():.4f} train={tr:.3f} val={val_acc:.3f} test={te:.3f}")

    if best_state is not None:
        model.load_state_dict(best_state)
    model.eval()
    with torch.no_grad():
        logits = model(data.x, data.edge_index)
    return {"val": accuracy(logits, data.y, data.val_mask),
            "test": accuracy(logits, data.y, data.test_mask)}
|
|
|
|
| |
| |
| |
|
|
def main():
    """Command-line entry point: train the baseline GCN or the pooled model.

    ``--variant baseline`` trains ``GCN2`` on the chosen Planetoid dataset.
    ``--variant pool`` (the default) loads the top-1 cluster from ``--seeds``
    and trains ``OneClusterPoolGated``. ``--seeds`` is only needed for the
    pool variant.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--dataset", required=True, choices=["Cora", "Citeseer", "Pubmed"])
    ap.add_argument("--seeds", default=None,
                    help="Path to LRMC seeds JSON (single large graph). Required for --variant pool.")
    ap.add_argument("--variant", choices=["baseline", "pool"], default="pool")
    ap.add_argument("--hidden", type=int, default=128)
    ap.add_argument("--epochs", type=int, default=200)
    ap.add_argument("--lr", type=float, default=0.01)
    ap.add_argument("--wd", type=float, default=5e-4)
    ap.add_argument("--dropout", type=float, default=0.5)
    ap.add_argument("--self_loop_scale", type=float, default=0.0)

    # Cluster-graph construction options (pool variant only).
    ap.add_argument("--use_a2", action="store_true", help="Include A^2 in cluster graph.")
    ap.add_argument("--a2_gamma", type=float, default=0.2, help="Weight for A^2 in A + γA^2.")
    ap.add_argument("--cluster_topk", type=int, default=24, help="Top-k neighbors per cluster row to keep.")
    ap.add_argument("--drop_cluster_self_loops", action="store_true", help="Drop (c,c) in cluster graph.")
    ap.add_argument("--refine_k", type=int, default=0, help="k-core refinement on the top cluster (e.g., 2).")

    ap.add_argument("--seed", type=int, default=42)
    ap.add_argument("--debug", action="store_true")
    args = ap.parse_args()

    # The baseline never reads the seeds file, so only the pool variant
    # insists on it.
    if args.variant == "pool" and not args.seeds:
        ap.error("--seeds is required when --variant is 'pool'")

    torch.manual_seed(args.seed)

    ds = Planetoid(root=f"./data/{args.dataset}", name=args.dataset)
    data = ds[0]
    in_dim, out_dim, n = ds.num_node_features, ds.num_classes, data.num_nodes

    if args.variant == "baseline":
        model = GCN2(in_dim, args.hidden, out_dim, dropout=args.dropout)
        res = run_train_eval(model, data, epochs=args.epochs, lr=args.lr, wd=args.wd)
        print(f"Baseline GCN: val={res['val']:.4f} test={res['test']:.4f}")
        return

    node2cluster, _, info = load_top1_assignment(
        args.seeds, n, debug=args.debug, refine_k=args.refine_k, edge_index_for_refine=data.edge_index
    )
    dbg_header = f"seeds_md5={info['json_md5']} top_size={info['top_cluster_size']} K={info['K']}"

    # NOTE: --dropout is only forwarded to the baseline model above.
    model = OneClusterPoolGated(
        in_dim=in_dim,
        hid=args.hidden,
        out_dim=out_dim,
        node2cluster=node2cluster,
        edge_index_node=data.edge_index,
        num_nodes=n,
        self_loop_scale=args.self_loop_scale,
        use_a2_for_clusters=args.use_a2,
        a2_gamma=args.a2_gamma,
        drop_cluster_self_loops=args.drop_cluster_self_loops,
        cluster_topk=args.cluster_topk,
        debug_header=dbg_header
    )

    res = run_train_eval(model, data, epochs=args.epochs, lr=args.lr, wd=args.wd)
    print(f"L-RMC (top-1 pool, gated): val={res['val']:.4f} test={res['test']:.4f}")


if __name__ == "__main__":
    main()
|
|