{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Dataset generator: glued tree \n",
    "\n",
    "The notebook was used to generate the glue_tree dataset used in the tree isomorphism experiment. \n",
    "\n",
    "The notebook contains:\n",
    "1. Demo code for generating and plotting a random tree and a random glued tree\n",
    "2. Code for generating a networkx dataset (this takes a lot of space on disk and is only used with small sample sizes)\n",
    "3. Code for generating a pytorch geometric dataset (this is more memory efficient)\n",
    "4. Calculation of dataset statistics\n",
    "\n",
    "The datasets have already been precomputed and stored on file. \n",
    "\n",
    "The code accompanies the paper *How hard is to distinguish graphs with graph neural networks*, submitted to NeurIPS 2020.\n",
    "\n",
    "Requirements: numpy, networkx, scipy, matplotlib, torch, torch_geometric, sklearn (pickle is part of the Python standard library)\n",
    "\n",
    "\n",
    "The anonymous author\n",
    "\n",
    "8 June 2020"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reload imported modules automatically before each cell runs, so edits to the helper files are picked up\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import networkx as nx\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "\n",
    "from sklearn.preprocessing import OneHotEncoder, OrdinalEncoder\n",
    "from dataset_tree import *\n",
    "import os\n",
    "import pickle \n",
    "import scipy as sp\n",
    "import math\n",
    "import pickle \n",
    "import time\n",
    "import pylab\n",
    "\n",
    "import torch\n",
    "from torch_geometric.data import Data, DataLoader\n",
    "\n",
    "from timeit import timeit\n",
    "from networkx.generators.nonisomorphic_trees import nonisomorphic_trees"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Set `datadir` to the location of the main folder; the datasets are read from and written to its `datasets` subfolder."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Root folder of the supplementary material; dataset pickles are stored under its 'datasets' subfolder\n",
    "datadir = 'supplementary-root-folder'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Random trees"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Uniformly random labeled tree generator used by the demo below.\n",
    "# networkx 3.2 deprecated random_tree in favour of random_labeled_tree and later releases dropped\n",
    "# the old name, so prefer the new generator and fall back to the legacy one on older installations.\n",
    "random_tree = getattr(nx, 'random_labeled_tree', None) or nx.generators.trees.random_tree"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "n_nodes = 20\n",
    "n_trees = 10\n",
    "\n",
    "# a single row of panels, each showing an independently sampled random tree\n",
    "fig = plt.figure(figsize=(4*n_trees,4))\n",
    "for i in range(n_trees):\n",
    "    ax = fig.add_subplot(1, n_trees, i+1)\n",
    "    sampled_tree = random_tree(n_nodes)\n",
    "    # draw on the panel just created (it is also the current axes)\n",
    "    nx.draw_networkx(sampled_tree, ax=ax)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Random glued tree"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Enumerate the non-isomorphic trees on 4 nodes; the demo below glues two of them together\n",
    "tree_universe = list(nonisomorphic_trees(4))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "n_universe = len(tree_universe)\n",
    "\n",
    "# pick two trees from the universe (fixed indices for this demo)\n",
    "tree1_idx, tree2_idx = 0, 1\n",
    "tree1 = tree_universe[tree1_idx]\n",
    "tree2 = tree_universe[tree2_idx]\n",
    "\n",
    "# glue them (glue_trees is provided by dataset_tree)\n",
    "graph = glue_trees(tree1, tree2)\n",
    "\n",
    "# randomly permute nodes: the first n_tree_nodes indices and the last n_tree_nodes indices are\n",
    "# shuffled separately, so each half of the index range is only permuted within itself\n",
    "n_tree_nodes = tree1.number_of_nodes()\n",
    "p = np.append(np.random.permutation(n_tree_nodes), n_tree_nodes+np.random.permutation(n_tree_nodes))\n",
    "\n",
    "# .toarray() returns a plain ndarray for scipy sparse matrices and sparse arrays alike, and\n",
    "# nx.from_numpy_array replaces nx.from_numpy_matrix, which was removed in networkx 3.0\n",
    "A = nx.adjacency_matrix(graph)[p,:][:,p].toarray()\n",
    "graph_relabeled = nx.from_numpy_array(A)\n",
    "\n",
    "# carry the node attributes over to the permuted graph\n",
    "# NOTE(review): assumes the nodes of `graph` are 0..2*n_tree_nodes-1 in adjacency order - confirm in glue_trees\n",
    "for node in graph_relabeled.nodes:\n",
    "    source = graph.nodes[p[node]]           # attributes of the node that now sits at index `node`\n",
    "    target = graph_relabeled.nodes[node]\n",
    "    target['owner'] = source['owner']\n",
    "    target['label'] = str(node % n_tree_nodes+1) # position-based label instead of source['label']\n",
    "    target['color'] = source['color']\n",
    "    target['pos'] = source['pos']\n",
    "    # presumably glue_trees labels the root of each tree '1' - verify against dataset_tree\n",
    "    target['is_root'] = source['label'] == '1'\n",
    "graph = graph_relabeled\n",
    "\n",
    "# compute the isomorphism class: the tree indices in increasing order, concatenated as a string\n",
    "# (a single index when both trees are the same)\n",
    "if tree1_idx == tree2_idx:\n",
    "    isomorphism_class = str(tree1_idx)\n",
    "else:\n",
    "    lo_idx, hi_idx = sorted((tree1_idx, tree2_idx))\n",
    "    isomorphism_class = str(lo_idx) + str(hi_idx)\n",
    "\n",
    "draw_glued_tree(graph_relabeled, figsize=(10,7))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Building an nx.graph list dataset \n",
    "This keeps all information but requires a lot of memory."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "n_samples = 10000 # choose how many samples we need\n",
    "save      = False # choose whether to save the data on file (possibly overwriting current data)\n",
    "\n",
    "for v in [4]: # in [4,5,6,7,8,9,10,11]\n",
    "\n",
    "    # construct the tree universe\n",
    "    tree_universe = list(nonisomorphic_trees(v))\n",
    "    n_universe    = len(tree_universe)\n",
    "\n",
    "    print(f'The tree-universe contains {n_universe} trees of {v} nodes')\n",
    "\n",
    "    # draw the samples; each record keeps the graph, its class name and the two tree indices\n",
    "    dataset = []\n",
    "    for i in range(n_samples):\n",
    "        graph, isomorphism_class, tree1_idx, tree2_idx = sample_glued_tree(tree_universe)\n",
    "        dataset.append({\n",
    "            'graph': graph,\n",
    "            'isomorphism_class': isomorphism_class,\n",
    "            'tree_idx': (tree1_idx, tree2_idx),\n",
    "        })\n",
    "\n",
    "    # one-hot encode the class names (one row per sample)\n",
    "    labels = [record['isomorphism_class'] for record in dataset]\n",
    "    enc = OneHotEncoder(categories='auto')\n",
    "    labels_encoded = enc.fit_transform([[name] for name in labels]).toarray()\n",
    "\n",
    "    # attach each one-hot row to its record\n",
    "    for record, one_hot in zip(dataset, labels_encoded):\n",
    "        record['label'] = one_hot\n",
    "\n",
    "    print(f'The {len(labels)} graphs sampled belong to {len(np.unique(np.array(labels)))} isomorphic classes')\n",
    "\n",
    "    # optionally write the dataset to file, wrapped in a one-element list\n",
    "    if save:\n",
    "        target_path = os.path.join(datadir, f'datasets/gluedtree_v:{v}_s:{n_samples}.pickle')\n",
    "        with open(target_path, 'wb') as f:\n",
    "            pickle.dump([dataset], f)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Building a pytorch dataset directly\n",
    "I used this version in the paper's experiments"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "n_samples = 10000 # choose how many samples we need\n",
    "save      = False # choose whether to save the data on file (possibly overwriting current data)\n",
    "\n",
    "for v in [4]: # in [4,5,6,7,8,9,10,11]\n",
    "\n",
    "    # construct the tree universe\n",
    "    universe = list(nonisomorphic_trees(v))\n",
    "    n_universe    = len(universe)\n",
    "\n",
    "    print(f'The tree-universe contains {n_universe} trees of {v} nodes')\n",
    "\n",
    "    # progress is reported about every 10% of the samples; the step is kept at 1 or more so the\n",
    "    # modulo below cannot raise ZeroDivisionError when n_samples is smaller than 10\n",
    "    progress_step = max(1, n_samples // 10)\n",
    "\n",
    "    dataset_torch, labels = [], []\n",
    "    for i in range(n_samples):\n",
    "\n",
    "        graph, label, _, _ = sample_glued_tree(universe)\n",
    "        n_graph_nodes = graph.number_of_nodes()\n",
    "\n",
    "        # every edge (a, b) is stored in both directions; the result has one column per directed edge\n",
    "        edge_index = np.reshape(np.array([([edge[0], edge[1], edge[1], edge[0]]) for edge in nx.to_edgelist(graph)]),(-1,2))\n",
    "        edge_index = torch.tensor(edge_index.T, dtype=torch.long)\n",
    "\n",
    "        # node features: column 0 = owner flag, column 1 = alice-root flag, remaining columns = one-hot node id\n",
    "        # (the builtin float replaces the np.float alias, which was removed in NumPy 1.24)\n",
    "        x = np.zeros((n_graph_nodes, 2+n_graph_nodes), dtype=float)\n",
    "        for node in graph.nodes:\n",
    "            attrs = graph.nodes[node]\n",
    "\n",
    "            # reveal the owner of each node (0 for alice, 1 otherwise)\n",
    "            x[node,0] = 0 if attrs['owner'] == 'alice' else 1\n",
    "\n",
    "            # reveal the root, but only the one owned by alice\n",
    "            if attrs['is_root'] and attrs['owner'] == 'alice':\n",
    "                x[node,1] = 1\n",
    "\n",
    "            # unique one-hot identifier: equivalent to copying row `node` of an identity matrix,\n",
    "            # without rebuilding that matrix for every node\n",
    "            x[node,2+node] = 1\n",
    "\n",
    "        x = torch.tensor(x, dtype=torch.float)\n",
    "        y = torch.tensor([0], dtype=torch.long) # placeholder, replaced by the encoded class below\n",
    "\n",
    "        data = Data(x=x, edge_index=edge_index, edge_attr=None, y=y)\n",
    "\n",
    "        dataset_torch.append(data)\n",
    "        labels.append(label)\n",
    "\n",
    "        if i % progress_step == 0: print(f'{100*(i/n_samples):3.0f}%')\n",
    "\n",
    "    # encode the class names as ordinal indices\n",
    "    enc = OrdinalEncoder(categories='auto')\n",
    "    labels_encoded = enc.fit_transform([[l] for l in labels])\n",
    "\n",
    "    # add categorical labels to the dataset\n",
    "    for idx in range(n_samples):\n",
    "        dataset_torch[idx].y = torch.tensor(labels_encoded[idx], dtype=torch.long)\n",
    "\n",
    "    n_classes = len(np.unique(np.array(labels)))\n",
    "    print(f'The {len(labels)} graphs sampled belong to {n_classes} isomorphic classes')\n",
    "\n",
    "    # save it to file\n",
    "    if save:\n",
    "        print('saving to disc..')\n",
    "        with open(os.path.join(datadir, f'datasets/gluedtree_v:{v}_s:{n_samples}_pytorch.pickle'), 'wb') as f:\n",
    "            pickle.dump([dataset_torch], f)\n",
    "        # report success only once the file has been closed\n",
    "        print('saved.')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Dataset statistics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch_geometric import utils\n",
    "from general_utils import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "v_all         = [4,5,6,7,8,9,10,11] \n",
    "n_nodes_all   = [2*n for n in v_all]\n",
    "n_samples_all = [10000,10000,10000,10000,40000,40000,40000,100000]\n",
    "pytorch       = [True,True,True,True,True,True,True,True]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "n_classes  = np.zeros(len(n_nodes_all))\n",
    "n_universe = np.zeros(len(n_nodes_all))\n",
    "stats_deg  = []\n",
    "stats_diam = []\n",
    "\n",
    "for v_idx, v in enumerate(v_all):\n",
    "\n",
    "    # construct the tree universe; the class count is the number of unordered pairs of trees,\n",
    "    # repetition allowed: k*(k+1)/2 for k trees\n",
    "    universe = list(nonisomorphic_trees(v))\n",
    "    n_universe[v_idx] = len(universe)\n",
    "    n_classes[v_idx] = n_universe[v_idx]*(n_universe[v_idx]+1)/2\n",
    "\n",
    "    # NOTE: pickle.load can execute arbitrary code - only load dataset files you generated or trust\n",
    "    try:\n",
    "        if not pytorch[v_idx]:\n",
    "            with open(os.path.join(datadir, f'datasets/gluedtree_v:{v}_s:{n_samples_all[v_idx]}.pickle'), 'rb') as f:\n",
    "                dataset = pickle.load(f)[0]\n",
    "\n",
    "            dataset_torch = glued_dataset_to_torch(dataset, unique_ids=True)\n",
    "\n",
    "        else:\n",
    "            with open(os.path.join(datadir, f'datasets/gluedtree_v:{v}_s:{n_samples_all[v_idx]}_pytorch.pickle'), 'rb') as f:\n",
    "                dataset_torch = pickle.load(f)[0]\n",
    "    except Exception as err:\n",
    "        # keep going with the remaining sizes, but report why this one failed and add NaN\n",
    "        # placeholders so the statistics lists stay aligned with v_all in the summary\n",
    "        print('error:', v, '|', repr(err))\n",
    "        stats_deg.append([np.nan])\n",
    "        stats_diam.append([np.nan])\n",
    "        continue\n",
    "\n",
    "    # per-graph statistics (diameter and degree_mean_fn are star-imported helpers, presumably from general_utils)\n",
    "    diameters, degrees = [], []\n",
    "    for data in dataset_torch:\n",
    "        graph = utils.convert.to_networkx(data)\n",
    "        diameters.append(diameter(graph))\n",
    "        degrees.append(degree_mean_fn(graph))\n",
    "    stats_deg.append(degrees)\n",
    "    stats_diam.append(diameters)\n",
    "\n",
    "    print(f'n = {n_nodes_all[v_idx]} | {n_classes[v_idx]} isomorphism classes | deg: {np.mean(degrees)} | diam: {np.mean(diameters)}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# summary rows, one entry per dataset: graph size, class count, mean degree, mean diameter\n",
    "print([format(size, '4.0f') for size in n_nodes_all])\n",
    "print([format(count, '4.0f') for count in n_classes])\n",
    "print([format(np.mean(per_graph), '2.1f') for per_graph in stats_deg])\n",
    "print([format(np.mean(per_graph), '2.1f') for per_graph in stats_diam])"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
