{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Tree isomorphism classification (training)\n",
    "\n",
    "The notebook trains GNNs of various widths and depths to classify connected graphs from the *glue_tree* datasets.\n",
    "\n",
    "It contains three parts: \n",
    "1. Train a single network. The code is used for demo purposes and does not save any result. \n",
    "2. Exhaustively train GNNs (main experiment). For every trained network, the code saves relevant statistics (training loss, training error, validation error, and test error) as well as the model itself to a pickle file. The code takes a *long* time to run from scratch (~2 weeks of GPU time). \n",
    "3. Train only large capacity GNNs (this corresponds to Table 2 in the supplementary material). Similar to the second experiment, the results have already been precomputed and stored for easy examination.\n",
    "\n",
    "The results are visualized by the `gnn_gluedtree_visualize.ipynb` notebook.\n",
    "\n",
    "The code accompanies the paper *How hard is to distinguish graphs with graph neural networks*, submitted to NeurIPS 2020.\n",
    "\n",
    "requirements: numpy, networkx, scipy, pickle, torch, torch_geometric\n",
    "\n",
    "\n",
    "The anonymous author\n",
    "\n",
    "8 June 2020"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reload imported modules before each cell is executed, so that edits to\n",
    "# imported source files are picked up without restarting the kernel.\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Change this to point to the location of the main folder"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Root folder used by every cell below: datasets are read from `datasets/`\n",
    "# and trained models/statistics are written to `results/` underneath it.\n",
    "datadir = 'supplementary-root-folder'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import networkx as nx\n",
    "import scipy as sp\n",
    "import math\n",
    "import pickle \n",
    "import time\n",
    "\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "from torch.nn import Sequential, Linear, ReLU\n",
    "from torch_geometric.data import Data, DataLoader\n",
    "from torch_geometric.nn import GINConv, global_add_pool\n",
    "\n",
    "from dataset_tree import *\n",
    "from models import *"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 1. Demo run"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# v_all[i] is paired with n_samples_all[i]; both values are part of the dataset filename.\n",
    "v_all         = [4,5,6,7,8,9,10,11] \n",
    "n_samples_all = [10000,10000,40000,40000,40000,40000,40000,100000]\n",
    "# True: load the '*_pytorch.pickle' file directly.\n",
    "# False: load the plain pickle and convert it with glued_dataset_to_torch.\n",
    "pytorch       = True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Index into v_all / n_samples_all: choose from 0 to 7.\n",
    "v_idx = 3\n",
    "\n",
    "# Look up the dataset parameters that belong to the chosen index.\n",
    "v = v_all[v_idx]\n",
    "n_samples = n_samples_all[v_idx]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if pytorch:\n",
    "    # Pre-converted dataset: the pickle already holds torch_geometric data objects.\n",
    "    dataset_path = os.path.join(datadir, f'datasets/gluedtree_v:{v}_s:{n_samples}_pytorch.pickle')\n",
    "    with open(dataset_path, 'rb') as f:\n",
    "        dataset_torch = pickle.load(f)[0]\n",
    "\n",
    "    # Number of classes = largest label found in the dataset, plus one.\n",
    "    n_classes = max([datum.y for datum in dataset_torch]).numpy()[0]+1\n",
    "\n",
    "else:\n",
    "    # Plain dataset: load it, then convert it to the torch representation.\n",
    "    dataset_path = os.path.join(datadir, f'datasets/gluedtree_v:{v}_s:{n_samples}.pickle')\n",
    "    with open(dataset_path, 'rb') as f:\n",
    "        dataset = pickle.load(f)[0]\n",
    "\n",
    "    dataset_torch = glued_dataset_to_torch(dataset, unique_ids=True)\n",
    "\n",
    "    # Number of classes = length of the label vector of the first sample.\n",
    "    n_classes = dataset[0]['label'].shape[0]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Train the GNN"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Train on the GPU when one is available, otherwise on the CPU.\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "\n",
    "# Drop any model/optimizer left over from an earlier run of this notebook.\n",
    "# Only NameError is expected here (fresh kernel, nothing defined yet);\n",
    "# any other error should surface instead of being silently swallowed.\n",
    "try:\n",
    "    del model, optimizer\n",
    "except NameError:\n",
    "    pass\n",
    "\n",
    "# Release cached GPU memory whether or not anything was deleted above.\n",
    "torch.cuda.empty_cache()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Capacity of the demo network and the number of training epochs.\n",
    "width, depth = 16, 6\n",
    "n_epochs = 1000\n",
    "\n",
    "# 90% / 5% / remainder split into train / validation / test, in dataset order.\n",
    "# Built-in int() replaces np.int, which NumPy deprecated in 1.20 and removed in 1.24.\n",
    "n_training_samples, n_valid_samples = int(0.9*n_samples), int(0.05*n_samples)\n",
    "train_loader = DataLoader(dataset_torch[:n_training_samples], batch_size=512, shuffle=True)\n",
    "valid_loader = DataLoader(dataset_torch[n_training_samples:n_training_samples+n_valid_samples], batch_size=2000, shuffle=True)\n",
    "test_loader  = DataLoader(dataset_torch[n_training_samples+n_valid_samples:], batch_size=2000, shuffle=True)\n",
    "\n",
    "# Node feature dimension, read from the first sample.\n",
    "n_features_node = dataset_torch[0].x.shape[1]\n",
    "\n",
    "model     = NetGIN(n_classes=n_classes, n_features_node=n_features_node, width=width, depth=depth).to(device)\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=0.001)\n",
    "\n",
    "t_start = time.time()\n",
    "for epoch in range(1, n_epochs+1):\n",
    "    loss = train(model, train_loader, device, epoch, optimizer)\n",
    "\n",
    "    # Report progress on epochs 1, 11, 21, ...\n",
    "    if epoch % 10 == 1:\n",
    "        acc_train  = test(model, train_loader, device)\n",
    "        acc_valid  = test(model, valid_loader, device)\n",
    "        print(f'{epoch:5d} | Loss: {loss:2.5f} | Train Acc: {acc_train:2.5f} | Valid Acc: {acc_valid:2.5f} | LR: {optimizer.param_groups[0][\"lr\"]:.6f} | {time.time() - t_start:4.3f} sec')\n",
    "\n",
    "# Final evaluation on all three splits; shown so the demo ends with a visible result.\n",
    "acc_train = test(model, train_loader, device)\n",
    "acc_valid = test(model, valid_loader, device)\n",
    "acc_test  = test(model, test_loader,  device)\n",
    "print(f'Final | Train Acc: {acc_train:2.5f} | Valid Acc: {acc_valid:2.5f} | Test Acc: {acc_test:2.5f}')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 2. Exhaustively test MPNNs of different capacities"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Train on the GPU when one is available, otherwise on the CPU.\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "\n",
    "# Drop the model/optimizer left over from the demo section (or an earlier run).\n",
    "# Only NameError is expected here (nothing defined yet); any other error\n",
    "# should surface instead of being silently swallowed.\n",
    "try:\n",
    "    del model, optimizer\n",
    "except NameError:\n",
    "    pass\n",
    "\n",
    "# Release cached GPU memory whether or not anything was deleted above.\n",
    "torch.cuda.empty_cache()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# v_all[i], n_samples_all[i] and pytorch_all[i] together describe one dataset;\n",
    "# the training cell below iterates over them by index.\n",
    "v_all         = [4,5,6,7,8,9,10,11]\n",
    "n_samples_all = [10000,10000,40000,40000,40000,40000,40000,100000]\n",
    "# Maximum number of epochs per network (training may stop earlier).\n",
    "n_epochs      = 1000\n",
    "# Tag inserted into the name of every result file.\n",
    "exp_name      = 'main'\n",
    "# True: load the '*_pytorch.pickle' file; False: load the plain pickle and convert it.\n",
    "pytorch_all   = [True,True,True,True,True,True,True,True]\n",
    "\n",
    "# One network is trained per (width, depth) pair, for every dataset.\n",
    "width_all     = [1,2,4,8,16]\n",
    "depth_all     = [2,3,4,5,6,7,8]       "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "         \n",
    "# Create the output folder up front so a finished training run is never lost\n",
    "# to a missing directory when the results are pickled.\n",
    "os.makedirs(os.path.join(datadir, 'results'), exist_ok=True)\n",
    "\n",
    "for v_idx, v in enumerate(v_all):\n",
    "\n",
    "    # Load the dataset for this v: either the plain pickle (converted here)\n",
    "    # or the already converted '*_pytorch.pickle'.\n",
    "    n_samples = n_samples_all[v_idx]\n",
    "    if not pytorch_all[v_idx]:\n",
    "        with open(os.path.join(datadir, f'datasets/gluedtree_v:{v}_s:{n_samples}.pickle'), 'rb') as f:\n",
    "            dataset = pickle.load(f)[0]\n",
    "        dataset_torch = glued_dataset_to_torch(dataset, unique_ids=True)\n",
    "    else:\n",
    "        with open(os.path.join(datadir, f'datasets/gluedtree_v:{v}_s:{n_samples}_pytorch.pickle'), 'rb') as f:\n",
    "            dataset_torch = pickle.load(f)[0]\n",
    "\n",
    "    # 90% / 5% / remainder split into train / validation / test, in dataset order.\n",
    "    # Built-in int() replaces np.int, which NumPy deprecated in 1.20 and removed in 1.24.\n",
    "    n_training_samples, n_valid_samples = int(0.9*n_samples), int(0.05*n_samples)\n",
    "    train_loader = DataLoader(dataset_torch[:n_training_samples], batch_size=512, shuffle=True)\n",
    "    valid_loader = DataLoader(dataset_torch[n_training_samples:n_training_samples+n_valid_samples], batch_size=1000, shuffle=True)\n",
    "    test_loader  = DataLoader(dataset_torch[n_training_samples+n_valid_samples:], batch_size=1000, shuffle=True)\n",
    "\n",
    "    # Number of classes = largest label in the dataset plus one;\n",
    "    # node feature dimension is read from the first sample.\n",
    "    n_classes = max([datum.y for datum in dataset_torch]).numpy()[0]+1\n",
    "    n_features_node = dataset_torch[0].x.shape[1]\n",
    "\n",
    "    print('\\n==============================================================')\n",
    "    print(f'v: {v}')\n",
    "    print('==============================================================')\n",
    "\n",
    "    for width_idx, width in enumerate(width_all):\n",
    "        for depth_idx, depth in enumerate(depth_all):\n",
    "\n",
    "            # Per-epoch statistics, NaN-initialised: epochs that are never\n",
    "            # reached or never evaluated stay NaN.\n",
    "            data_loss      = np.zeros((n_epochs)) * np.nan\n",
    "            data_acc_train = np.zeros((n_epochs)) * np.nan\n",
    "            data_acc_valid = np.zeros((n_epochs)) * np.nan\n",
    "            data_acc_test  = np.nan\n",
    "\n",
    "            print(f'\\n## width: {width}, depth: {depth} ##')\n",
    "\n",
    "            model     = NetGIN_readout(n_classes=n_classes, n_features_node=n_features_node, width=width, depth=depth).to(device)\n",
    "            optimizer = torch.optim.Adam(model.parameters(), lr=0.001)\n",
    "\n",
    "            t_start = time.time()\n",
    "            for iepoch, epoch in enumerate(range(1,n_epochs+1)):\n",
    "\n",
    "                loss = train(model, train_loader, device, epoch, optimizer)\n",
    "                data_loss[iepoch] = loss\n",
    "\n",
    "                # Evaluate on epochs 1, 21, 41, ...\n",
    "                if epoch % 20 == 1:\n",
    "                    acc_train  = test(model, train_loader, device)\n",
    "                    acc_valid  = test(model, valid_loader, device)\n",
    "                    data_acc_train[iepoch] = acc_train\n",
    "                    data_acc_valid[iepoch] = acc_valid\n",
    "\n",
    "                    print(f'{epoch:5d} | Loss: {loss:2.4f} | Train Acc: {acc_train:2.4f} | Valid Acc: {acc_valid:2.4f} | LR: {optimizer.param_groups[0][\"lr\"]:.5f} | {(time.time() - t_start)/epoch:2.3f} sec/epoch')\n",
    "\n",
    "                    # early stopping: validation accuracy is already perfect\n",
    "                    if acc_valid == 1: break\n",
    "\n",
    "            # Final evaluation, stored at the index of the last epoch that ran.\n",
    "            acc_train = test(model, train_loader, device)\n",
    "            acc_valid = test(model, valid_loader, device)\n",
    "            acc_test  = test(model, test_loader,  device)\n",
    "\n",
    "            data_acc_train[iepoch] = acc_train\n",
    "            data_acc_valid[iepoch] = acc_valid\n",
    "            data_acc_test = acc_test\n",
    "\n",
    "            print(f' loss: {loss}, train_acc: {acc_train}, valid_acc: {acc_valid}, test_acc {acc_test}')\n",
    "\n",
    "            # Save the trained model together with its statistics.\n",
    "            filename = f'gluedtree_{exp_name}_v:{v}_s:{n_samples}_w:{width}_d:{depth}.pickle'\n",
    "            with open(os.path.join(datadir, f'results/{filename}'), 'wb') as f:\n",
    "                pickle.dump([model, data_loss, data_acc_train, data_acc_valid, data_acc_test], f)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "scrolled": false
   },
   "source": [
    "# 3. How well do large capacity MPNNs fare? "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# v_all[i], n_samples_all[i] and pytorch_all[i] together describe one dataset;\n",
    "# the training cell below iterates over them by index.\n",
    "v_all         = [4,5,6,7,8,9,10,11]\n",
    "n_samples_all = [10000,10000,40000,40000,40000,40000,40000,100000]\n",
    "# Maximum number of epochs per network (training may stop earlier).\n",
    "n_epochs      = 2000\n",
    "# Tag inserted into the name of every result file.\n",
    "exp_name      = 'same-capacity'\n",
    "# True: load the '*_pytorch.pickle' file; False: load the plain pickle and convert it.\n",
    "pytorch_all   = [True,True,True,True,True,True,True,True]\n",
    "\n",
    "# A single (width, depth) configuration is trained for every dataset.\n",
    "width_all     = [32]\n",
    "depth_all     = [10]    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# Create the output folder up front so a finished training run is never lost\n",
    "# to a missing directory when the results are pickled.\n",
    "os.makedirs(os.path.join(datadir, 'results'), exist_ok=True)\n",
    "\n",
    "for v_idx, v in enumerate(v_all):\n",
    "\n",
    "    # Load the dataset for this v: either the plain pickle (converted here)\n",
    "    # or the already converted '*_pytorch.pickle'.\n",
    "    n_samples = n_samples_all[v_idx]\n",
    "    if not pytorch_all[v_idx]:\n",
    "        with open(os.path.join(datadir, f'datasets/gluedtree_v:{v}_s:{n_samples}.pickle'), 'rb') as f:\n",
    "            dataset = pickle.load(f)[0]\n",
    "        dataset_torch = glued_dataset_to_torch(dataset, unique_ids=True)\n",
    "    else:\n",
    "        with open(os.path.join(datadir, f'datasets/gluedtree_v:{v}_s:{n_samples}_pytorch.pickle'), 'rb') as f:\n",
    "            dataset_torch = pickle.load(f)[0]\n",
    "\n",
    "    # 90% / 5% / remainder split into train / validation / test, in dataset order.\n",
    "    # Built-in int() replaces np.int, which NumPy deprecated in 1.20 and removed in 1.24.\n",
    "    n_training_samples, n_valid_samples = int(0.9*n_samples), int(0.05*n_samples)\n",
    "    train_loader = DataLoader(dataset_torch[:n_training_samples], batch_size=512, shuffle=True)\n",
    "    valid_loader = DataLoader(dataset_torch[n_training_samples:n_training_samples+n_valid_samples], batch_size=1000, shuffle=True)\n",
    "    test_loader  = DataLoader(dataset_torch[n_training_samples+n_valid_samples:], batch_size=1000, shuffle=True)\n",
    "\n",
    "    # Number of classes = largest label in the dataset plus one;\n",
    "    # node feature dimension is read from the first sample.\n",
    "    n_classes = max([datum.y for datum in dataset_torch]).numpy()[0]+1\n",
    "    n_features_node = dataset_torch[0].x.shape[1]\n",
    "\n",
    "    print('\\n==============================================================')\n",
    "    print(f'v: {v}')\n",
    "    print('==============================================================')\n",
    "\n",
    "    for width_idx, width in enumerate(width_all):\n",
    "        for depth_idx, depth in enumerate(depth_all):\n",
    "\n",
    "            # Per-epoch statistics, NaN-initialised: epochs that are never\n",
    "            # reached or never evaluated stay NaN.\n",
    "            data_loss      = np.zeros((n_epochs)) * np.nan\n",
    "            data_acc_train = np.zeros((n_epochs)) * np.nan\n",
    "            data_acc_valid = np.zeros((n_epochs)) * np.nan\n",
    "            data_acc_test  = np.nan\n",
    "\n",
    "            print(f'\\n## width: {width}, depth: {depth} ##')\n",
    "\n",
    "            model     = NetGIN(n_classes=n_classes, n_features_node=n_features_node, width=width, depth=depth).to(device)\n",
    "            optimizer = torch.optim.Adam(model.parameters(), lr=0.001)\n",
    "\n",
    "            t_start = time.time()\n",
    "            for iepoch, epoch in enumerate(range(1,n_epochs+1)):\n",
    "\n",
    "                loss = train(model, train_loader, device, epoch, optimizer)\n",
    "                data_loss[iepoch] = loss\n",
    "\n",
    "                # Evaluate on epochs 1, 21, 41, ...\n",
    "                if epoch % 20 == 1:\n",
    "                    acc_train  = test(model, train_loader, device)\n",
    "                    acc_valid  = test(model, valid_loader, device)\n",
    "                    data_acc_train[iepoch] = acc_train\n",
    "                    data_acc_valid[iepoch] = acc_valid\n",
    "\n",
    "                    print(f'{epoch:5d} | Loss: {loss:2.4f} | Train Acc: {acc_train:2.4f} | Valid Acc: {acc_valid:2.4f} | LR: {optimizer.param_groups[0][\"lr\"]:.5f} | {(time.time() - t_start)/epoch:2.3f} sec/epoch')\n",
    "\n",
    "                    # early stopping: validation accuracy is already perfect\n",
    "                    if acc_valid == 1: break\n",
    "\n",
    "            # Final evaluation, stored at the index of the last epoch that ran.\n",
    "            acc_train = test(model, train_loader, device)\n",
    "            acc_valid = test(model, valid_loader, device)\n",
    "            acc_test  = test(model, test_loader,  device)\n",
    "\n",
    "            data_acc_train[iepoch] = acc_train\n",
    "            data_acc_valid[iepoch] = acc_valid\n",
    "            data_acc_test = acc_test\n",
    "\n",
    "            print(f' loss: {loss}, train_acc: {acc_train}, valid_acc: {acc_valid}, test_acc {acc_test}')\n",
    "\n",
    "            # Save the trained model together with its statistics.\n",
    "            filename = f'gluedtree_{exp_name}_v:{v}_s:{n_samples}_w:{width}_d:{depth}.pickle'\n",
    "            with open(os.path.join(datadir, f'results/{filename}'), 'wb') as f:\n",
    "                pickle.dump([model, data_loss, data_acc_train, data_acc_valid, data_acc_test], f)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
