"""
Export a PyG dataset (Planetoid or Heterophilous) to a minimal Graphviz DOT file.

- By default, nodes are colored by class.
- If --filter {train|val|test} is set, nodes in that split are colored red and
  all other nodes are left uncolored.
- Self-loops are always dropped. Undirected edges are deduplicated; directed
  edges are otherwise written as-is.
- For HeterophilousGraphDataset with multiple splits (e.g., Amazon-ratings has 10),
  use --split-index (default 0).
"""
|
|
| import argparse |
| from torch_geometric.datasets import Planetoid, HeterophilousGraphDataset |
|
|
# Ten base class colors as "#rrggbb" strings; _extend_palette derives lighter
# variants when a dataset has more classes than entries here.
# NOTE(review): values look like matplotlib's "tab10" cycle -- confirm if the
# provenance matters.
PALETTE = [
    "#1f77b4",
    "#ff7f0e",
    "#2ca02c",
    "#d62728",
    "#9467bd",
    "#8c564b",
    "#e377c2",
    "#7f7f7f",
    "#bcbd22",
    "#17becf",
]

# Color applied to nodes that belong to the split selected with --filter.
HIGHLIGHT_RED = "#ff0000"
|
|
def _extend_palette(base, need):
    """Return ``need`` hex colors derived from ``base``.

    When ``base`` already holds enough colors, its first ``need`` entries are
    returned unchanged. Otherwise the palette is cycled, and each additional
    pass over ``base`` blends the colors 18% further toward white so that
    repeated hues remain distinguishable.

    Args:
        base: Sequence of ``"#rrggbb"`` color strings.
        need: Number of colors wanted; values <= 0 yield an empty list.

    Returns:
        A new list containing exactly ``max(need, 0)`` hex color strings.

    Raises:
        ValueError: If at least one color is requested but ``base`` is empty.
    """
    # A negative count used to slice from the end (base[:-1]) and hand back a
    # truncated palette; treat every non-positive request as "no colors".
    if need <= 0:
        return []
    if len(base) == 0:
        # Previously surfaced as an opaque ZeroDivisionError from the modulo.
        raise ValueError("cannot build a palette from an empty base color list")
    if need <= len(base):
        return list(base[:need])

    def lighten(hex_color, factor):
        """Blend ``hex_color`` toward white by ``factor`` (0 keeps it as-is)."""
        value = int(hex_color[1:], 16)
        channels = ((value >> 16) & 255, (value >> 8) & 255, value & 255)
        r, g, b = (int(min(255, ch + (255 - ch) * factor)) for ch in channels)
        return f"#{r:02x}{g:02x}{b:02x}"

    size = len(base)
    # NOTE: once the factor reaches 1 (pass index 6 and beyond, i.e. more than
    # six times len(base) colors) every channel clamps to 255 and the extra
    # colors are all white.
    return [lighten(base[i % size], 0.18 * (i // size)) for i in range(need)]
|
|
def _infer_num_classes(dataset, data):
    """Return the number of label classes for ``data``.

    The dataset's own ``num_classes`` attribute is used when it is a positive
    ``int``. Otherwise the count is derived from the labels as
    ``int(data.y.max().item()) + 1``.
    """
    declared = getattr(dataset, "num_classes", None)
    if isinstance(declared, int) and declared > 0:
        return declared
    # Fall back to the largest label value found in the graph itself.
    largest_label = data.y.max().item()
    return int(largest_label) + 1
|
|
def _get_split_mask(data, which):
    """Return the mask attribute of ``data`` for the requested split.

    Args:
        data: Object exposing split masks as attributes.
        which: One of ``"train"``, ``"val"`` or ``"test"``.

    Returns:
        The first mask attribute that is present and not ``None``, or ``None``
        when no matching attribute exists. For ``"val"`` the names tried are
        ``val_mask``, then ``valid_mask``, then ``validation_mask``.

    Raises:
        KeyError: If ``which`` is not one of the three split names.
    """
    primary = {"train": "train_mask", "val": "val_mask", "test": "test_mask"}[which]
    candidates = [primary]
    if which == "val":
        candidates += ["valid_mask", "validation_mask"]
    for attr in candidates:
        mask = getattr(data, attr, None)
        # Compare against None explicitly: chaining lookups with ``or`` takes
        # the mask's truth value, which raises for multi-element tensors and
        # silently skips masks that happen to be falsy.
        if mask is not None:
            return mask
    return None
|
|
def _select_split_column(data, split_index):
    """Reduce 2-D split masks on ``data`` to the single split ``split_index``.

    Masks that are missing, not 2-D, or have no columns are left untouched.
    ``split_index`` is clamped to the available column range.
    """
    for mask_name in ("train_mask", "val_mask", "test_mask"):
        mask = getattr(data, mask_name, None)
        if mask is None or mask.ndim != 2 or mask.shape[1] == 0:
            continue
        column = max(0, min(split_index, mask.shape[1] - 1))
        setattr(data, mask_name, mask[:, column])


def load_graph(root: str, use_hetero: bool, name: str, split_index: int):
    """Load one graph and its class count from a PyG dataset.

    Args:
        root: Root folder for downloaded datasets. Planetoid datasets are
            stored under ``{root}/Planetoid``; heterophilous ones directly
            under ``root``.
        use_hetero: Load ``HeterophilousGraphDataset`` when true, otherwise
            ``Planetoid``.
        name: Dataset name passed through to the dataset class.
        split_index: Which predefined split to expose for heterophilous
            datasets; out-of-range values are clamped. Ignored for Planetoid.

    Returns:
        A ``(data, num_classes)`` tuple.
    """
    if use_hetero:
        ds = HeterophilousGraphDataset(root=root, name=name)
        # Kept from the original: clamp the index into the dataset itself.
        data = ds[max(0, min(split_index, len(ds) - 1))]
        # NOTE(review): PyG's HeterophilousGraphDataset appears to hold one
        # graph whose train/val/test masks have one column per split
        # (shape [num_nodes, num_splits]) -- confirm against the installed
        # version. Indexing the dataset alone therefore never selects a
        # split, and per-node truth tests on a 2-D mask raise. Pick the
        # requested column so downstream code sees 1-D masks.
        _select_split_column(data, split_index)
    else:
        ds = Planetoid(root=f"{root}/Planetoid", name=name)
        data = ds[0]
    return data, _infer_num_classes(ds, data)
|
|
def write_dot(path: str, data, num_classes: int, directed: bool, filter_split: str | None):
    """Write ``data`` to ``path`` as a Graphviz DOT file.

    Args:
        path: Destination file, opened for writing as UTF-8.
        data: Graph object providing ``y``, ``edge_index``, ``num_nodes`` and
            optionally the split mask attributes.
        num_classes: Number of classes used to size the color palette; labels
            are wrapped into ``[0, num_classes)`` with a modulo.
        directed: Emit a ``digraph`` with ``->`` edges when true, otherwise a
            ``graph`` with ``--`` edges where each unordered pair is written
            once.
        filter_split: ``None`` to fill every node with its class color, or a
            split name to fill only that split's nodes in red and write all
            other nodes without attributes.

    Self-loops are skipped in both modes.
    """
    labels = data.y
    edges = data.edge_index
    class_colors = _extend_palette(PALETTE, num_classes)

    split_mask = None
    if filter_split is not None:
        raw_mask = _get_split_mask(data, filter_split)
        if raw_mask is not None:
            split_mask = raw_mask.bool()

    if directed:
        graph_keyword, edge_op = "digraph", "->"
    else:
        graph_keyword, edge_op = "graph", "--"

    with open(path, "w", encoding="utf-8") as out:
        out.write(f"{graph_keyword} {{\n")

        for node in range(data.num_nodes):
            # Looked up for every node, including in filter mode where the
            # class color itself ends up unused.
            class_color = class_colors[int(labels[node]) % num_classes]

            if filter_split is None:
                fill = class_color
            elif split_mask is not None and bool(split_mask[node]):
                fill = HIGHLIGHT_RED
            else:
                # Filter mode, node outside the split: no attributes at all.
                out.write(f' {node} ;\n')
                continue
            out.write(f' {node} [color="{fill}", style="filled", fillcolor="{fill}", fontcolor="white"];\n')

        emitted = set()
        for src, dst in edges.t().tolist():
            if src == dst:
                continue
            if not directed:
                # Normalize to (low, high) so both directions of an
                # undirected edge collapse onto one key.
                if dst < src:
                    src, dst = dst, src
                if (src, dst) in emitted:
                    continue
                emitted.add((src, dst))
            out.write(f" {src} {edge_op} {dst};\n")

        out.write("}\n")
|
|
def main():
    """Parse command-line options, load the dataset, and write the DOT file.

    When --name is omitted the dataset defaults to "Amazon-ratings" with
    --heterophilous and to "Cora" otherwise. A one-line summary is printed
    after the file has been written.
    """
    parser = argparse.ArgumentParser(
        description="Export PyG datasets to minimal DOT with class colors and optional split highlighting."
    )
    parser.add_argument("-o", "--output", default="graph.dot", help="Output .dot file (default: graph.dot)")
    parser.add_argument("--directed", action="store_true", help="Write directed edges (default: undirected)")
    parser.add_argument("--heterophilous", action="store_true", help="Use HeterophilousGraphDataset")
    parser.add_argument("--name", default=None, help="Dataset name (Planetoid: Cora/CiteSeer/PubMed; Heterophilous: Amazon-ratings, Roman-empire, etc.)")
    parser.add_argument("--root", default="data", help="Root folder for datasets (default: data)")
    parser.add_argument("--split-index", type=int, default=0, help="Split index for heterophilous datasets (default: 0)")
    # write_dot emits nodes outside the selected split without any color
    # attributes, so the help text must not promise class colors for them.
    parser.add_argument(
        "--filter",
        choices=["train", "val", "test"],
        default=None,
        help="Highlight nodes in the selected split as red (other nodes are left uncolored)",
    )
    args = parser.parse_args()

    dataset_name = args.name if args.name is not None else ("Amazon-ratings" if args.heterophilous else "Cora")
    data, num_classes = load_graph(args.root, args.heterophilous, dataset_name, args.split_index)
    write_dot(args.output, data, num_classes, args.directed, args.filter)

    suffix = f", highlight={args.filter}" if args.filter else ""
    # The edge count is the number of stored edge_index columns, not the
    # number of edge lines written to the file.
    print(f"Wrote {args.output} using {'HeterophilousGraphDataset' if args.heterophilous else 'Planetoid'}('{dataset_name}') | nodes={data.num_nodes}, edges={data.edge_index.size(1)}, classes={num_classes}{suffix}")
|
|
# Run the exporter only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()