{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "2eq2Z1JGYLy7" }, "source": [ "# Assignment 2: Create an ML model based on Anonymous Walk Embeddings for node-level prediction" ] },
{ "cell_type": "markdown", "source": [ "- Pramod Manohar Dalavi - 2023aa05398@wilp.bits-pilani.ac.in\n", "- Utkarsh Kumar Verma - 2023ab05014@wilp.bits-pilani.ac.in\n", "- Ankita Laxmikant Bahirat - 2023aa05952@wilp.bits-pilani.ac.in\n", "- Charu Mathur - 2023aa05055@wilp.bits-pilani.ac.in\n", "- K Mamatha - 2023ab05018@wilp.bits-pilani.ac.in" ], "metadata": { "id": "UZPVjT1hwxOF" } },
{ "cell_type": "markdown", "metadata": { "id": "QKRg53KHYUGc" }, "source": [ "## Generate Graph Embeddings" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "S0YbS1rMYKQg", "outputId": "57474ad6-f275-424a-e3ca-01cb32fea82e" }, "outputs": [], "source": [
 "from collections import Counter\n",
 "\n",
 "import networkx as nx\n",
 "import numpy as np\n",
 "\n",
 "SEED = 42  # fixed seed so the sampled walks are reproducible\n",
 "\n",
 "\n",
 "def generate_anonymous_walks(graph, walk_length, num_walks, seed=None):\n",
 "    \"\"\"Sample `num_walks` random walks starting from every node of `graph`.\n",
 "\n",
 "    The walks still carry the original node labels; the anonymization step\n",
 "    happens in `anonymous_walk_embedding`.\n",
 "\n",
 "    Args:\n",
 "        graph: networkx graph (anything with `.nodes()` and `.neighbors(node)`).\n",
 "        walk_length: maximum number of nodes per walk, start node included.\n",
 "        num_walks: number of walks started from each node.\n",
 "        seed: optional int for reproducibility; when None the global\n",
 "            `np.random` state is used (so `np.random.seed(...)` still applies).\n",
 "\n",
 "    Returns:\n",
 "        List of `num_walks * len(graph.nodes())` walks (lists of node labels).\n",
 "        A walk stops early when it reaches a node without neighbours.\n",
 "    \"\"\"\n",
 "    rng = np.random if seed is None else np.random.default_rng(seed)\n",
 "    walks = []\n",
 "    for _ in range(num_walks):\n",
 "        for node in graph.nodes():\n",
 "            walk = [node]\n",
 "            for _ in range(walk_length - 1):\n",
 "                neighbors = list(graph.neighbors(walk[-1]))\n",
 "                if not neighbors:\n",
 "                    break  # dead end: keep the shorter walk\n",
 "                # Sample an index, not the label: np.random.choice(neighbors) raises for\n",
 "                # tuple labels (e.g. grid graphs) and turns other labels into numpy scalars.\n",
 "                walk.append(neighbors[rng.choice(len(neighbors))])\n",
 "            walks.append(walk)\n",
 "    return walks\n",
 "\n",
 "\n",
 "def anonymous_walk_embedding(graph, walk_length=5, num_walks=10, seed=None):\n",
 "    \"\"\"Return the anonymous walks of `graph`, one per sampled random walk.\n",
 "\n",
 "    Each walk is relabelled by order of first appearance, e.g. [a, b, a, c]\n",
 "    becomes [0, 1, 0, 2], so only its structural pattern remains.\n",
 "    `seed` is forwarded to `generate_anonymous_walks`.\n",
 "    \"\"\"\n",
 "    walks = generate_anonymous_walks(graph, walk_length, num_walks, seed=seed)\n",
 "    anon_walks = []\n",
 "    for walk in walks:\n",
 "        first_seen = {}\n",
 "        # setdefault hands out the next id only the first time a node appears.\n",
 "        anon_walks.append([first_seen.setdefault(node, len(first_seen)) for node in walk])\n",
 "    return anon_walks\n",
 "\n",
 "\n",
 "def anonymous_walk_distribution(anon_walks):\n",
 "    \"\"\"Relative frequency of each anonymous-walk pattern (the AWE feature vector).\n",
 "\n",
 "    Returns a dict {pattern tuple: frequency}, most frequent first; the\n",
 "    frequencies sum to 1. An empty input gives an empty dict.\n",
 "    \"\"\"\n",
 "    counts = Counter(tuple(walk) for walk in anon_walks)\n",
 "    total = sum(counts.values())\n",
 "    return {pattern: count / total for pattern, count in counts.most_common()}\n",
 "\n",
 "\n",
 "# Example usage: show a summary instead of dumping every walk.\n",
 "G = nx.karate_club_graph()\n",
 "embeddings = anonymous_walk_embedding(G, seed=SEED)\n",
 "print(f\"{len(embeddings)} anonymous walks, first 5: {embeddings[:5]}\")\n",
 "anonymous_walk_distribution(embeddings)" ] },
{ "cell_type": "markdown", "metadata": { "id": "fgMkAPGpYX7A" }, "source": [ "## Dataset Preparation" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { 
"colab": { "base_uri": "https://localhost:8080/" }, "id": "4eiZ0-p0Y_HY", "outputId": "573d7c65-b94d-4168-90c8-48429fa837e6" }, "outputs": [], "source": [
 "# Install once, pinned to the versions this notebook was last run with;\n",
 "# %pip targets the running kernel's environment.\n",
 "%pip install -q ogb==1.3.6 torch-geometric==2.6.1\n",
 "import ogb\n",
 "import torch_geometric\n",
 "\n",
 "print(ogb.__version__)\n",
 "print(torch_geometric.__version__)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "4B8MF6xVZ6qa", "outputId": "30a2c8a5-15f6-4304-a8ec-2bacddefb7a2" }, "outputs": [ { "output_type": "stream", "name": "stdout", 
"text": [ "1.3.6\n", "2.6.1\n" ] } ], "source": [ "import ogb\n", "import torch_geometric\n", "\n", "print(ogb.__version__)\n", "print(torch_geometric.__version__)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "mAQtd8gjaBb0", "outputId": "e948110b-f189-4696-bf8a-437ec7ddab3e" }, "outputs": [], "source": [
 "# Optional environment reset; `-y` skips pip's interactive \"Proceed (Y/n)?\" prompt,\n",
 "# which would otherwise block the cell.\n",
 "#%pip uninstall -y ogb torch-geometric\n",
 "#%pip install ogb==1.3.6 torch-geometric==2.6.1" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "ZKT-dMUbYZj1" }, "outputs": [], "source": [
 "import torch\n",
 "from ogb.graphproppred import PygGraphPropPredDataset\n",
 "from torch_geometric.loader import DataLoader  # torch_geometric.data.DataLoader is a deprecated alias\n",
 "\n",
 "# Load dataset\n",
 "dataset = PygGraphPropPredDataset(name=\"ogbg-molhiv\")\n",
 "split_idx = dataset.get_idx_split()\n",
 "train_loader = DataLoader(dataset[split_idx[\"train\"]], batch_size=32, shuffle=True)\n",
 "valid_loader = DataLoader(dataset[split_idx[\"valid\"]], batch_size=32, shuffle=False)\n",
 "test_loader = DataLoader(dataset[split_idx[\"test\"]], batch_size=32, shuffle=False)" ] },
{ "cell_type": "markdown", "metadata": { "id": "pdNjBmBJYcOP" }, "source": [ "## Neural Network Model" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "5tTY5vQ3Yd3y" }, "outputs": [], "source": [
 "import torch\n",
 "from torch_geometric.nn import GCNConv, global_mean_pool\n",
 "\n",
 "class GNN(torch.nn.Module):\n",
 "    \"\"\"Two GCN layers + mean pooling + linear head: one logit per task per graph.\"\"\"\n",
 "\n",
 "    def __init__(self, hidden_channels):\n",
 "        super().__init__()\n",
 "        self.conv1 = GCNConv(dataset.num_node_features, hidden_channels)\n",
 "        self.conv2 = GCNConv(hidden_channels, hidden_channels)\n",
 "        self.lin = torch.nn.Linear(hidden_channels, dataset.num_tasks)\n",
 "\n",
 "    def forward(self, x, edge_index, batch):\n",
 "        # GCNConv needs floating-point features, but the raw OGB atom features\n",
 "        # are integer ids (the last cell casts them to float for the same reason).\n",
 "        x = x.float()\n",
 "        x = self.conv1(x, edge_index).relu()\n",
 "        x = self.conv2(x, edge_index).relu()\n",
 "        x = global_mean_pool(x, batch)  # graph-level readout\n",
 "        return self.lin(x)\n",
 "\n",
 "model = GNN(hidden_channels=64)" ] },
{ "cell_type": "markdown", "metadata": { "id": "Ot4ypsx1YgBq" }, "source": [ "## Model Optimization" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "ldHLgvdJYh86" }, "outputs": [], "source": [
 "optimizer = torch.optim.Adam(model.parameters(), lr=0.01)\n",
 "criterion = torch.nn.BCEWithLogitsLoss()\n",
 "\n",
 "def train():\n",
 "    \"\"\"Run one epoch over `train_loader` and return the mean loss per graph.\"\"\"\n",
 "    model.train()\n",
 "    total_loss, total_graphs = 0.0, 0\n",
 "    for data in train_loader:\n",
 "        optimizer.zero_grad()\n",
 "        out = model(data.x, data.edge_index, data.batch)\n",
 "        loss = criterion(out, data.y.float())\n",
 "        loss.backward()\n",
 "        optimizer.step()\n",
 "        total_loss += loss.item() * data.num_graphs\n",
 "        total_graphs += data.num_graphs\n",
 "    return total_loss / max(total_graphs, 1)" ] },
{ "cell_type": "markdown", "metadata": { "id": "9MWIVvcRYlpB" }, "source": [ "## Evaluation" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "Ui9ln9PlbzZD" }, "outputs": [], "source": [
 "# Inspect the first graph; GCNConv needs float node features, so check the dtype.\n",
 "sample_data = dataset[0]\n",
 "print(sample_data)\n",
 "print('x dtype:', sample_data.x.dtype)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "colab": { "background_save": true }, "id": "J-gzaiejcpvx" }, "outputs": [], "source": [
 "import torch\n",
 "from ogb.graphproppred import GraphPropPredDataset, Evaluator\n",
 "from torch_geometric.data import Data\n",
 "from torch_geometric.loader import DataLoader\n",
 "from torch_geometric.nn import GCNConv, global_mean_pool\n",
 "\n",
 "# NOTE(review): this model only uses the raw atom features; the anonymous-walk\n",
 "# features from the first cell are not fed into it yet (TODO).\n",
 "DATASET_NAME = \"ogbg-molhiv\"\n",
 "BATCH_SIZE = 32\n",
 "HIDDEN_CHANNELS = 64\n",
 "LEARNING_RATE = 0.01\n",
 "NUM_EPOCHS = 100\n",
 "torch.manual_seed(42)  # reproducible weight init and batch shuffling\n",
 "\n",
 "# Load dataset\n",
 "dataset = GraphPropPredDataset(name=DATASET_NAME)\n",
 "split_idx = dataset.get_idx_split()\n",
 "num_tasks = dataset.num_tasks\n",
 "\n",
 "# Convert split indices to lists of integers\n",
 "train_idx = split_idx[\"train\"].tolist()\n",
 "valid_idx = split_idx[\"valid\"].tolist()\n",
 "test_idx = split_idx[\"test\"].tolist()\n",
 "\n",
 "# Convert dataset to PyTorch Geometric Data objects\n",
 "def convert_to_pyg_data(data_dict, label):\n",
 "    \"\"\"Build a PyG `Data` object from one OGB graph dict and its label array.\"\"\"\n",
 "    return Data(\n",
 "        x=torch.tensor(data_dict['node_feat'], dtype=torch.float),\n",
 "        edge_index=torch.tensor(data_dict['edge_index'], dtype=torch.long),\n",
 "        edge_attr=torch.tensor(data_dict['edge_feat'], dtype=torch.float),\n",
 "        # Shape (1, num_tasks) so batching yields (num_graphs, num_tasks), matching\n",
 "        # the model output; view(-1, 1) was only correct for single-task datasets.\n",
 "        y=torch.tensor(label, dtype=torch.float).view(1, -1),\n",
 "    )\n",
 "\n",
 "train_data = [convert_to_pyg_data(dataset[i][0], dataset[i][1]) for i in train_idx]\n",
 "valid_data = [convert_to_pyg_data(dataset[i][0], dataset[i][1]) for i in valid_idx]\n",
 "test_data = [convert_to_pyg_data(dataset[i][0], dataset[i][1]) for i in test_idx]\n",
 "\n",
 "train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)\n",
 "valid_loader = DataLoader(valid_data, batch_size=BATCH_SIZE, shuffle=False)\n",
 "test_loader = DataLoader(test_data, batch_size=BATCH_SIZE, shuffle=False)\n",
 "\n",
 "# Determine the number of node features\n",
 "num_node_features = train_data[0].x.shape[1]\n",
 "\n",
 "class GNN(torch.nn.Module):\n",
 "    \"\"\"Two GCN layers + mean pooling + linear head: one logit per task per graph.\"\"\"\n",
 "\n",
 "    def __init__(self, hidden_channels):\n",
 "        super().__init__()\n",
 "        self.conv1 = GCNConv(num_node_features, hidden_channels)\n",
 "        self.conv2 = GCNConv(hidden_channels, hidden_channels)\n",
 "        self.lin = torch.nn.Linear(hidden_channels, num_tasks)\n",
 "\n",
 "    def forward(self, x, edge_index, batch):\n",
 "        x = x.float()  # no-op here (features are already float); matches the earlier GNN\n",
 "        x = self.conv1(x, edge_index).relu()\n",
 "        x = self.conv2(x, edge_index).relu()\n",
 "        x = global_mean_pool(x, batch)  # graph-level readout\n",
 "        return self.lin(x)\n",
 "\n",
 "model = GNN(hidden_channels=HIDDEN_CHANNELS)\n",
 "optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)\n",
 "criterion = torch.nn.BCEWithLogitsLoss()\n",
 "evaluator = Evaluator(name=DATASET_NAME)\n",
 "\n",
 "# Training loop\n",
 "def train():\n",
 "    \"\"\"Run one epoch over `train_loader` and return the mean loss per graph.\"\"\"\n",
 "    model.train()\n",
 "    total_loss, total_graphs = 0.0, 0\n",
 "    for data in train_loader:\n",
 "        optimizer.zero_grad()\n",
 "        out = model(data.x, data.edge_index, data.batch)\n",
 "        loss = criterion(out, data.y)\n",
 "        loss.backward()\n",
 "        optimizer.step()\n",
 "        total_loss += loss.item() * data.num_graphs\n",
 "        total_graphs += data.num_graphs\n",
 "    return total_loss / max(total_graphs, 1)\n",
 "\n",
 "# Evaluation\n",
 "@torch.no_grad()\n",
 "def evaluate(loader):\n",
 "    \"\"\"Return the OGB ROC-AUC of `model` over every graph in `loader`.\"\"\"\n",
 "    model.eval()\n",
 "    y_true, y_pred = [], []\n",
 "    for data in loader:\n",
 "        out = model(data.x, data.edge_index, data.batch)\n",
 "        # The evaluator expects (num_graphs, num_tasks) arrays.\n",
 "        y_true.append(data.y.view(-1, num_tasks).cpu())\n",
 "        y_pred.append(out.view(-1, num_tasks).cpu())\n",
 "    y_true = torch.cat(y_true, dim=0).numpy()\n",
 "    y_pred = torch.cat(y_pred, dim=0).numpy()\n",
 "    return evaluator.eval({\"y_true\": y_true, \"y_pred\": y_pred})[\"rocauc\"]\n",
 "\n",
 "# Model selection: report the test score at the epoch with the best validation score.\n",
 "best_epoch, best_valid, best_test = 0, float('-inf'), float('nan')\n",
 "for epoch in range(1, NUM_EPOCHS + 1):\n",
 "    loss = train()\n",
 "    valid_rocauc = evaluate(valid_loader)\n",
 "    test_rocauc = evaluate(test_loader)\n",
 "    if valid_rocauc > best_valid:\n",
 "        best_epoch, best_valid, best_test = epoch, valid_rocauc, test_rocauc\n",
 "    print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Validation ROC-AUC: {valid_rocauc:.4f}, Test ROC-AUC: {test_rocauc:.4f}')\n",
 "print(f'Best epoch: {best_epoch:03d}, Validation ROC-AUC: {best_valid:.4f}, Test ROC-AUC: {best_test:.4f}')" ] }
], "metadata": { "colab": { "provenance": [] }, "kernelspec": { "display_name": "Python 3", "name": "python3" }, "language_info": { "name": "python" } }, "nbformat": 4, "nbformat_minor": 0 }