{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "Compare GAN Modules Colab",
      "version": "0.3.2",
      "views": {},
      "default_view": {},
      "provenance": [],
      "private_outputs": true,
      "collapsed_sections": [
        "7N-eCy31umdo"
      ]
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    }
  },
  "cells": [
    {
      "metadata": {
        "id": "7N-eCy31umdo",
        "colab_type": "text"
      },
      "cell_type": "markdown",
      "source": [
        "##### Copyright 2018 The TensorFlow Hub Authors.\n",
        "\n",
        "Licensed under the Apache License, Version 2.0 (the \"License\");"
      ]
    },
    {
      "metadata": {
        "id": "X_lFsiKEnM33",
        "colab_type": "code",
        "colab": {
          "autoexec": {
            "startup": false,
            "wait_interval": 0
          }
        }
      },
      "cell_type": "code",
      "source": [
        "# Copyright 2018 The TensorFlow Hub Authors. All Rights Reserved.\n",
        "#\n",
        "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "#     http://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License.\n",
        "# =============================================================================="
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "id": "65PLbu3MviWn",
        "colab_type": "text"
      },
      "cell_type": "markdown",
      "source": [
        "## Compare GAN models Colab\n",
        "\n",
        "This Colab notebook shows how to use a collection of pre-trained generative adversarial network models (GANs) for CIFAR10, CelebA HQ (128x128) and LSUN bedroom datasets to generate images. Random vectors are fed into the latent space to generate RGB images using the pre-trained generators.\n",
        "\n",
        "For the details of the setup, please refer to our paper: [The GAN Landscape: Losses, Architectures, Regularization, and Normalization](https://arxiv.org/abs/1807.04720).\n",
        "\n",
        "The code used to train these models is available at: https://github.com/google/compare_gan\n",
        "\n",
        "These models are available as [TensorFlow Hub Modules](https://tfhub.dev).\n",
        "\n",
        "<table align=\"left\">\n",
        "<td>\n",
        "  <a target=\"_blank\"  href=\"https://github.com/google/compare_gan/blob/master/compare_gan/src/tfhub_models.ipynb\">\n",
        "    <img width=32px src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />View notebook source on GitHub</a>\n",
        "</td></table>"
      ]
    },
    {
      "metadata": {
        "id": "1A9p20T2uxKO",
        "colab_type": "text"
      },
      "cell_type": "markdown",
      "source": [
        "## View the available GAN modules\n",
        "\n",
        "Run the cells and select the GAN module to use from the table below."
      ]
    },
    {
      "metadata": {
        "id": "Uj7C7cE_v7Ev",
        "colab_type": "code",
        "colab": {
          "autoexec": {
            "startup": false,
            "wait_interval": 0
          }
        },
        "cellView": "form"
      },
      "cell_type": "code",
      "source": [
        "#@title Imports, set up, and helper functions\n",
        "\n",
        "from google.colab import output\n",
        "\n",
        "import matplotlib.pyplot as plt\n",
        "\n",
        "import numpy as np\n",
        "import pandas as pd\n",
        "\n",
        "import tensorflow as tf\n",
        "import tensorflow_hub as hub\n",
        "\n",
        "\n",
        "tf.logging.set_verbosity(tf.logging.ERROR)\n",
        "\n",
        "\n",
        "# Get the module metadata for the GANs as a pandas DataFrame.\n",
        "module_metadata_dict = {'dataset': ['CelebA HQ (128x128)', 'CelebA HQ (128x128)', 'LSUN Bedroom', 'LSUN Bedroom', 'CelebA HQ (128x128)', 'CelebA HQ (128x128)', 'LSUN Bedroom', 'LSUN Bedroom', 'CelebA HQ (128x128)', 'LSUN Bedroom', 'CIFAR10', 'CIFAR10', 'CIFAR10', 'CIFAR10', 'CIFAR10'], 'penalty': ['-', '-', '-', '-', '-', '-', '-', '-', 'DRAGAN (lambda=1.000)', 'WGAN (lambda=0.145)', '-', '-', '-', '-', 'WGAN (lambda=1.000)'], 'architecture': ['ResNet19', 'ResNet19', 'ResNet19', 'ResNet19', 'ResNet19', 'ResNet19', 'ResNet19', 'ResNet19', 'ResNet19', 'ResNet19', 'ResNet CIFAR', 'ResNet CIFAR', 'ResNet CIFAR', 'ResNet CIFAR', 'ResNet CIFAR'], 'beta1': ['0.375', '0.500', '0.585', '0.195', '0.500', '0.500', '0.500', '0.102', '0.500', '0.711', '0.500', '0.500', '0.500', '0.500', '0.500'], 'beta2': ['0.998', '0.999', '0.990', '0.882', '0.999', '0.999', '0.999', '0.998', '0.900', '0.979', '0.999', '0.999', '0.999', '0.999', '0.999'], 'module_url': ['https://tfhub.dev/google/compare_gan/model_1_celebahq128_resnet19/1', 'https://tfhub.dev/google/compare_gan/model_2_celebahq128_resnet19/1', 'https://tfhub.dev/google/compare_gan/model_3_lsun_bedroom_resnet19/1', 'https://tfhub.dev/google/compare_gan/model_4_lsun_bedroom_resnet19/1', 'https://tfhub.dev/google/compare_gan/model_5_celebahq128_resnet19/1', 'https://tfhub.dev/google/compare_gan/model_6_celebahq128_resnet19/1', 'https://tfhub.dev/google/compare_gan/model_7_lsun_bedroom_resnet19/1', 'https://tfhub.dev/google/compare_gan/model_8_lsun_bedroom_resnet19/1', 'https://tfhub.dev/google/compare_gan/model_9_celebahq128_resnet19/1', 'https://tfhub.dev/google/compare_gan/model_10_lsun_bedroom_resnet19/1', 'https://tfhub.dev/google/compare_gan/model_11_cifar10_resnet_cifar/1', 'https://tfhub.dev/google/compare_gan/model_12_cifar10_resnet_cifar/1', 'https://tfhub.dev/google/compare_gan/model_13_cifar10_resnet_cifar/1', 'https://tfhub.dev/google/compare_gan/model_14_cifar10_resnet_cifar/1', 'https://tfhub.dev/google/compare_gan/model_15_cifar10_resnet_cifar/1'], 'disc_iters': [1, 1, 1, 1, 1, 1, 1, 1, 5, 1, 5, 5, 5, 5, 5], 'model': ['Non-saturating GAN', 'Non-saturating GAN', 'Least-squares GAN', 'Non-saturating GAN', 'Non-saturating GAN', 'Non-saturating GAN', 'Least-squares GAN', 'Non-saturating GAN', 'Non-saturating GAN', 'Non-saturating GAN', 'Non-saturating GAN', 'Non-saturating GAN', 'Non-saturating GAN', 'Non-saturating GAN', 'Non-saturating GAN'], 'inception_score': ['2.38', '2.59', '4.23', '4.10', '2.38', '2.54', '3.64', '3.58', '2.34', '3.92', '7.57', '7.47', '7.74', '7.74', '7.70'], 'disc_norm': ['none', 'none', 'none', 'none', 'layer_norm', 'layer_norm', 'spectral_norm', 'spectral_norm', 'layer_norm', 'layer_norm', 'none', 'none', 'spectral_norm', 'spectral_norm', 'spectral_norm'], 'fid': ['34.29', '35.85', '102.74', '112.92', '30.02', '32.05', '41.60', '42.51', '29.13', '40.36', '28.12', '30.08', '22.91', '23.22', '22.73'], 'ms_ssim_score': ['0.32', '0.29', 'N/A', 'N/A', '0.29', '0.28', 'N/A', 'N/A', '0.30', 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', 'N/A'], 'learning_rate': ['3.381e-05', '1.000e-04', '3.220e-05', '1.927e-05', '1.000e-04', '1.000e-04', '2.000e-04', '2.851e-04', '1.000e-04', '1.281e-04', '2.000e-04', '1.000e-04', '2.000e-04', '2.000e-04', '2.000e-04']}\n",
        "MODULE_METADATA = pd.DataFrame.from_dict(module_metadata_dict)\n",
        "\n",
        "# To start, select the module with the lowest FID score.\n",
        "MIN_FID_MODULE = MODULE_METADATA.loc[\n",
        "    MODULE_METADATA['fid'].astype(float).idxmin()]\n",
        "\n",
        "SELECTED_MODULE = MIN_FID_MODULE['module_url']\n",
        "SELECTED_MODULE_DATASET = MIN_FID_MODULE['dataset']\n",
        "\n",
        "\n",
        "# Display multiple images in the same figure.\n",
        "def display_images(images, captions=None):\n",
        "  batch_size, dim1, dim2, channels = images.shape\n",
        "  num_horizontally = 8\n",
        "  \n",
        "  # Use a smaller figure size for the CIFAR10 images\n",
        "  figsize = (20, 20) if dim1 > 32 else (10, 10)\n",
        "  f, axes = plt.subplots(\n",
        "      len(images) // num_horizontally, num_horizontally, figsize=figsize)\n",
        "  for i in range(len(images)):\n",
        "    axes[i // num_horizontally, i % num_horizontally].axis(\"off\")\n",
        "    if captions is not None:\n",
        "      axes[i // num_horizontally, i % num_horizontally].text(0, -3, captions[i])\n",
        "    axes[i // num_horizontally, i % num_horizontally].imshow(images[i])\n",
        "  f.tight_layout()\n",
        "  \n",
        "\n",
        "# Show the HTML for the module table\n",
        "class ShowModuleTable(object):\n",
        "  def __init__(self, callback):\n",
        "    self._callback = callback\n",
        "\n",
        "  def _repr_html_(self):\n",
        "    # Set up the template with some nice CSS.\n",
        "    template = \"\"\"\n",
        "    <style>\n",
        "       table {\n",
        "         font-size: 15px;\n",
        "         font-family: Inconsolata, monospace;\n",
        "         border-collapse: collapse;\n",
        "         border: 1px solid #444444;\n",
        "       }\n",
        "       th {\n",
        "         font-size: 18px;\n",
        "         background-color: #DDDDDD;\n",
        "         border: 1px solid #AAAAAA;\n",
        "         white-space: nowrap;\n",
        "       }\n",
        "       tr {\n",
        "         cursor: pointer;\n",
        "         white-space: nowrap;\n",
        "       }\n",
        "       td {\n",
        "         padding: 6px 6px 6px 6px;\n",
        "         border: 1px solid #AAAAAA;\n",
        "       }\n",
        "      .selected-row {\n",
        "        font-weight: bold;\n",
        "        background-color: #B0BED9;\n",
        "      }\n",
        "    </style>\n",
        "    <table>\"\"\"\n",
        "    \n",
        "    # Set up the headers with nicely readable names\n",
        "    table_headers = [\n",
        "      ('dataset', 'Dataset'),\n",
        "      ('architecture', 'Architecture'),\n",
        "      ('fid', 'FID'),\n",
        "      ('inception_score', 'IS'),\n",
        "      ('ms_ssim_score', 'MS-SSIM'),\n",
        "      ('model', 'Model'),\n",
        "      ('learning_rate', 'Learning rate'),\n",
        "      ('beta1', '&beta;<sub>1</sub>'),\n",
        "      ('beta2', '&beta;<sub>2</sub>'),\n",
        "      ('disc_iters', 'n<sub>disc</sub>'),\n",
        "      ('disc_norm', 'Disc norm'),\n",
        "      ('penalty', 'Penalty'),\n",
        "      ('module_url', 'Module name'),\n",
        "    ]\n",
        "    header_template = \"<tr>\"\n",
        "    for _, header_name in table_headers:\n",
        "      header_template += \"<th>\" + header_name + \"</th>\"\n",
        "    header_template += \"</tr>\"\n",
        "    template += header_template\n",
        "    \n",
        "    for i, (_, row) in enumerate(MODULE_METADATA.iterrows()):\n",
        "      uuid = \"row-%s\" % i\n",
        "      \n",
        "      # Reister the callback for every row.\n",
        "      output.register_callback(uuid, self._callback)\n",
        "      \n",
        "      # By default select the module with the min FID.\n",
        "      selected_class = \"\"\n",
        "      if row['module_url'] == MIN_FID_MODULE['module_url']:\n",
        "        selected_class = \"class=\\\"selected-row\\\"\"\n",
        "\n",
        "      # Get the metadata for each row.\n",
        "      row_template = \"<tr id=\\\"\" + uuid + \"\\\"\" + selected_class + \">\"\n",
        "      for key, _ in table_headers:\n",
        "        row_template += \"<td>\" + str(row[key]) + \"</td>\"\n",
        "      row_template += \"</tr>\"\n",
        "      template += row_template\n",
        "      \n",
        "    # Add the onclick handlers for the table rows.\n",
        "    template += \"\"\"\n",
        "      </table>\n",
        "      <script>\"\"\"\n",
        "    \n",
        "    for i, (_, row) in enumerate(MODULE_METADATA.iterrows()):\n",
        "      uuid = \"row-%s\" % i\n",
        "      m = row['module_url']\n",
        "      d = row['dataset']\n",
        "      template += \"\"\"\n",
        "        document.querySelector(\\\"#\"\"\" + uuid + \"\"\"\\\").onclick = function() {\n",
        "          google.colab.kernel.invokeFunction('\"\"\" + uuid + \"\"\"', ['\"\"\" + m +\"\"\"', '\"\"\" + d + \"\"\"'], {});\n",
        "          var selected = document.getElementsByClassName('selected-row');\n",
        "          for (var i = 0; i < selected.length; i++) {\n",
        "            selected[i].classList.remove('selected-row');\n",
        "          }\n",
        "          this.classList.toggle(\"selected-row\");\n",
        "          e.preventDefault();\n",
        "        };\n",
        "        \"\"\"\n",
        "    template += \"\"\"</script>\"\"\"\n",
        "    return template\n",
        "\n",
        "\n",
        "def set_selected_module(module_name, dataset):\n",
        "  # Assign the selected module and dataset to the global variables\n",
        "  global SELECTED_MODULE\n",
        "  SELECTED_MODULE = module_name\n",
        "  global SELECTED_MODULE_DATASET\n",
        "  SELECTED_MODULE_DATASET = dataset\n"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "id": "ziU4PhGh283a",
        "colab_type": "code",
        "colab": {
          "autoexec": {
            "startup": false,
            "wait_interval": 0
          }
        },
        "cellView": "form"
      },
      "cell_type": "code",
      "source": [
        "#@title Run this cell and select which GAN module to use below\n",
        "\n",
        "ShowModuleTable(set_selected_module)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "id": "sgzX499wqPTv",
        "colab_type": "text"
      },
      "cell_type": "markdown",
      "source": [
        "## Generate images from the selected module"
      ]
    },
    {
      "metadata": {
        "id": "ZP0P3TkGoBAx",
        "colab_type": "code",
        "colab": {
          "autoexec": {
            "startup": false,
            "wait_interval": 0
          }
        }
      },
      "cell_type": "code",
      "source": [
        "assert SELECTED_MODULE is not None and SELECTED_MODULE_DATASET is not None, \\\n",
        "  'You must run the above cell and select a module from the table to generate images.'\n",
        "\n",
        "print('Using module: \"%s\"' % SELECTED_MODULE)\n",
        "print('Generating images like dataset: \"%s\"' % SELECTED_MODULE_DATASET)\n",
        "\n",
        "# The generator expects a batch of 64 vectors of size 128\n",
        "batch_size = 64\n",
        "z_dim = 128\n",
        "\n",
        "with tf.Graph().as_default():\n",
        "  # Load the selected module\n",
        "  gan = hub.Module(SELECTED_MODULE)\n",
        "  z_input = tf.placeholder(dtype=tf.float32, shape=(batch_size, z_dim))\n",
        "  image_output = gan(z_input, signature=\"generator\") \n",
        "  \n",
        "  with tf.train.MonitoredSession() as session:\n",
        "    # Generate 64 random vectors as input to the latent space to generate images\n",
        "    z_values = np.random.uniform(-1, 1, size=(batch_size, z_dim))\n",
        "    images = session.run(image_output, feed_dict={z_input: z_values})\n",
        "\n",
        "    # View the resulting images\n",
        "    display_images(images)"
      ],
      "execution_count": 0,
      "outputs": []
    }
  ]
}