{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Generalized Sliced-Wasserstein Flows with Neural Networks\n",
    "\n",
    "The goal of this experiment is to illustrate the effects of the Generalized Sliced-Wasserstein (GSW) and maximum Generalized Sliced-Wasserstein (max-GSW) distance, whose defining function is learned through a neural network.\n",
    "\n",
    "## Experiment details\n",
    "\n",
    "We consider the following problem:\n",
    "$$\\operatorname{min}_{p_Y} GSW_p(p_X, p_Y),$$\n",
    "\n",
    "where $p_X$ is a target distribution and $p_Y$ is the source distribution, which is initialized to the normal distribution. \n",
    "\n",
    "The optimization is solved iteratively via\n",
    "$$ \\partial_t (p_Y)_t= -\\nabla GSW_p(p_X, (p_Y)_t), ~~(p_Y)_0=\\mathcal{N}(0, (0.25)^2).$$\n",
    "\n",
    "We also consider $\\operatorname{min}_{\\mu} \\{ \\max\\text{-}GSW_p(p_X, p_Y) \\},$ and we use the same optimization scheme to solve it (with $\\max\\text{-}GSW_p$ in place of $GSW_p$). \n",
    "\n",
    "We use 5 well-known distributions as the target: the 25-Gaussians, 8-Gaussians, Swiss Roll, Half Moons and Circle distributions. \n",
    "\n",
    "The defining function is learned through a neural network. We compare different configurations: we use a multilayer perceptron of depth 1, 2 or 3. \n",
    "\n",
    "We analyze the results (i) qualitatively, by plotting samples drawn from $p_X$ and $(p_Y)_t$ at each iteration $t$ of the optimization process, and (ii) quantitatively, by computing and reporting the 2-Wasserstein distance between $p_X$ and $(p_Y)_t$ at each $t$.\n",
    "\n",
    "## Requirements\n",
    "\n",
    "* Numpy\n",
    "* Scikit-learn\n",
    "* PyTorch\n",
    "* POT"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append('../gsw')\n",
    "\n",
    "import numpy as np\n",
    "from gswnn import GSW_NN\n",
    "from gsw_utils import w2,load_data\n",
    "\n",
    "import torch\n",
    "from torch import nn\n",
    "from torch.nn import functional as F\n",
    "from torch.autograd import Function\n",
    "from torch.nn.parameter import Parameter\n",
    "from torch import optim\n",
    "\n",
    "from tqdm import tqdm\n",
    "from IPython import display\n",
    "import time\n",
    "import pickle \n",
    "import matplotlib.pyplot as plt\n",
    "import random\n",
    "import os"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "np.random.seed(10)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### We choose a dataset and load it\n",
    "### The dataset name must be 'swiss_roll', 'half_moons', 'circle', '8gaussians' or '25gaussians'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset_name = 'swiss_roll'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "N = 1000  # Number of samples from p_X\n",
    "X = load_data(name=dataset_name, n_samples=N)\n",
    "X -= X.mean(dim=0)[np.newaxis,:]  # Normalization\n",
    "meanX = 0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show the dataset\n",
    "_, d = X.shape\n",
    "fig = plt.figure(figsize=(5,5))\n",
    "plt.scatter(X[:,0], X[:,1])\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### We create the different folders to store the results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "results_folder = './saved_results_GSW_flows_NN'\n",
    "if not os.path.isdir(results_folder):\n",
    "    os.mkdir(results_folder)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "foldername = os.path.join(results_folder, 'Gifs')\n",
    "if not os.path.isdir(foldername):\n",
    "    os.mkdir(foldername)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "foldername = os.path.join(results_folder, 'Gifs', dataset_name + '_Comparison_NN')\n",
    "if not os.path.isdir(foldername):\n",
    "    os.mkdir(foldername)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### We solve the two optimization problems for different neural network configurations and plot the results at each step"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Use GPU if available, CPU otherwise\n",
    "if torch.cuda.is_available():\n",
    "    device = torch.device('cuda')\n",
    "else:\n",
    "    device = torch.device('cpu')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of iterations for the optimization process\n",
    "nofiterations = 250"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define the variables to store the loss (2-Wasserstein distance) for each defining function and each problem \n",
    "w2_dist = np.nan * np.zeros((nofiterations, 3))\n",
    "maxw2_dist = np.nan * np.zeros((nofiterations, 3))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define the different neural networks architectures\n",
    "depth = [1, 2, 3, 1, 2, 3]\n",
    "titles=['GSW NN - Depth=1', 'GSW NN - Depth=2', 'GSW NN - Depth=3', \n",
    "        'MaxGSW NN - Depth=1', 'MaxGSW NN - Depth=2', 'MaxGSW NN - Depth=3']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define the initial distribution\n",
    "temp = np.random.normal(loc=meanX, scale=.25, size=(N,d))\n",
    "\n",
    "# Define the optimizers\n",
    "Y=list()\n",
    "optimizer=list()\n",
    "gsw=list()\n",
    "\n",
    "for i in range(6):\n",
    "    Y.append(torch.tensor(temp, dtype=torch.float, device=device, requires_grad=True))\n",
    "    optimizer.append(optim.Adam([Y[-1]], lr = 1e-2))\n",
    "    gsw.append(GSW_NN(din=2, nofprojections=1, model_depth=depth[i]))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "fig=plt.figure(figsize=(15, 15))\n",
    "grid = plt.GridSpec(3, 3, wspace=0.4, hspace=0.3)\n",
    "\n",
    "for i in range(nofiterations):            \n",
    "    loss=list()\n",
    "    # We loop over the different neural networks configurations for the GSW problem\n",
    "    for k in range(3):\n",
    "        # Loss computation (here, GSW)\n",
    "        loss = gsw[k].gsw(X.to(device), Y[k].to(device))\n",
    "        \n",
    "        # Optimization step\n",
    "        optimizer[k].zero_grad()\n",
    "        loss.backward()\n",
    "        optimizer[k].step()\n",
    "        \n",
    "        # Compute the 2-Wasserstein distance to compare the distributions\n",
    "        w2_dist[i, k] = w2(X.detach().cpu().numpy(),Y[k].detach().cpu().numpy())  \n",
    "        \n",
    "        # Plot samples from the target and the current solution\n",
    "        temp = Y[k].detach().cpu().numpy()\n",
    "        plt.subplot(grid[0, k])\n",
    "        plt.scatter(X[:, 0], X[:, 1])\n",
    "        plt.scatter(temp[:, 0], temp[:, 1], c='r') \n",
    "        plt.title(titles[k], fontsize=22)\n",
    "    \n",
    "    # We loop over the different neural networks configurations for the max-GSW problem \n",
    "    for k in range(3,6):\n",
    "        # Loss computation (here, max-GSW)\n",
    "        loss = gsw[k].max_gsw(X.to(device), Y[k].to(device), iterations=250, lr=1e-4)\n",
    "        \n",
    "        # Optimization step\n",
    "        optimizer[k].zero_grad()\n",
    "        loss.backward()\n",
    "        optimizer[k].step() \n",
    "        \n",
    "        # Compute the 2-Wasserstein distance to compare the distributions\n",
    "        maxw2_dist[i, k-3] = w2(X.detach().cpu().numpy(), Y[k].detach().cpu().numpy())  \n",
    "        \n",
    "        # Plot samples from the target and the current solution\n",
    "        temp = Y[k].detach().cpu().numpy()\n",
    "        plt.subplot(grid[1, k-3])\n",
    "        plt.scatter(X[:, 0], X[:, 1])\n",
    "        plt.scatter(temp[:, 0], temp[:, 1],c='r') \n",
    "        plt.title(titles[k], fontsize=22)    \n",
    "    \n",
    "    # Plot the 2-Wasserstein distance\n",
    "    plt.subplot(grid[2, 0:3])\n",
    "    plt.plot(np.log10(w2_dist[:,:]), linewidth=3)\n",
    "    plt.plot(np.log10(maxw2_dist[:,:]), linewidth=3)\n",
    "    plt.title('2-Wasserstein Distance', fontsize=22)\n",
    "    plt.ylabel(r'$Log_{10}(W_2)$', fontsize=22)\n",
    "    \n",
    "    plt.legend(titles, fontsize=22, loc='lower left')\n",
    "    display.clear_output(wait=True)\n",
    "    display.display(plt.gcf()) \n",
    "    time.sleep(1e-5)    \n",
    "    \n",
    "    # Save the figure\n",
    "    fig.savefig(foldername+'/img%03d.png'%(i))\n",
    "    for k in range(3):\n",
    "        plt.subplot(grid[:, k])\n",
    "        plt.cla()\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### We save the results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "filename = os.path.join(results_folder, dataset_name + '_comparison_NN_1run.pkl')\n",
    "with open(filename, 'wb') as f:\n",
    "    pickle.dump([w2_dist, maxw2_dist], f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import imageio\n",
    "from glob import glob\n",
    "from skimage.transform import resize"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "filenames = glob(foldername + '/*.png')\n",
    "images = []\n",
    "for filename in filenames:\n",
    "    images.append((resize(imageio.imread(filename).astype(float) / 255., (750, 750, 4)) * 255).astype('uint8'))\n",
    "imageio.mimsave(dataset_name + '_comparison_NN.gif', images)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
