{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "organic-horizon",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import torch\n",
    "from torch.autograd import grad, Variable\n",
    "import autograd\n",
    "# NOTE: `np` is autograd's numpy wrapper (autograd.numpy), not plain numpy.\n",
    "import autograd.numpy as np\n",
    "import copy\n",
    "import scipy as sp\n",
    "from scipy import stats\n",
    "from sklearn import metrics\n",
    "import sys\n",
    "import ot\n",
    "import gwot\n",
    "# NOTE: two of the names imported on the next line are rebound later in this notebook:\n",
    "# `models` by the plain `import models` below, and `sim` by the Simulation object\n",
    "# assigned to `sim` in the simulation cell.\n",
    "from gwot import models, sim, ts, util\n",
    "import gwot.bridgesampling as bs\n",
    "import dcor\n",
    "\n",
    "# Add fig1_batch/ (relative path) to the module search path for the imports below.\n",
    "sys.path.append(\"fig1_batch/\")\n",
    "import importlib\n",
    "# Rebinds `models` (until now gwot.models). Presumably this resolves to a module in\n",
    "# fig1_batch/ -- confirm, since any `models` found earlier on sys.path would win.\n",
    "import models\n",
    "# Reload so that edits to the module are picked up when this cell is re-run.\n",
    "importlib.reload(models)\n",
    "import random\n",
    "import model_fig1 as model_sim"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "needed-appliance",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# Thread budget, kept as a string because environment variable values must be strings.\n",
    "num_threads = \"8\"\n",
    "\n",
    "# Apply the same thread budget to each backend-specific environment variable.\n",
    "# NOTE(review): torch and numpy are already imported by the first cell when these are\n",
    "# set -- confirm the libraries still honour values set after import.\n",
    "for _thread_env_var in (\n",
    "    \"OMP_NUM_THREADS\",\n",
    "    \"OPENBLAS_NUM_THREADS\",\n",
    "    \"MKL_NUM_THREADS\",\n",
    "    \"VECLIB_MAXIMUM_THREADS\",\n",
    "    \"NUMEXPR_NUM_THREADS\",\n",
    "):\n",
    "    os.environ[_thread_env_var] = num_threads\n",
    "\n",
    "# torch takes the thread count as an integer.\n",
    "torch.set_num_threads(int(num_threads))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "institutional-puppy",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Use the first CUDA device when torch reports one as available, otherwise the CPU.\n",
    "if torch.cuda.is_available():\n",
    "    device = torch.device(\"cuda:0\")\n",
    "else:\n",
    "    device = torch.device(\"cpu\")\n",
    "\n",
    "# Make float64 the default dtype for floating-point tensors created from here on.\n",
    "torch.set_default_dtype(torch.float64)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "integral-flesh",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Base length used to scale the `figsize` of the figures created with plt.figure below.\n",
    "PLT_CELL = 2.5"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "standard-gossip",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Seed every random number generator that is in scope so a fresh run is reproducible.\n",
    "SRAND = 0\n",
    "torch.manual_seed(SRAND)\n",
    "np.random.seed(SRAND)\n",
    "# `random` is imported in the first cell but was never seeded; seed it as well so that\n",
    "# any code drawing from Python's built-in generator is deterministic too.\n",
    "random.seed(SRAND)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "removable-banks",
   "metadata": {},
   "outputs": [],
   "source": [
    "# M is used below as the particle count per time point of the fitted model;\n",
    "# N is the per-entry value of the `N` array handed to the simulation.\n",
    "M = 100\n",
    "N = 64\n",
    "\n",
    "# NOTE: this rebinds `sim`, which the import cell had bound to the gwot.sim module.\n",
    "# The class is reached through `gwot.sim`, so the rebinding does not affect this call.\n",
    "sim = gwot.sim.Simulation(\n",
    "    V=model_sim.Psi,\n",
    "    dV=model_sim.dPsi,\n",
    "    birth_death=False,\n",
    "    N=np.full(model_sim.T, N),  # length-T array filled with N\n",
    "    T=model_sim.T,\n",
    "    d=model_sim.dim,\n",
    "    D=model_sim.D,\n",
    "    t_final=model_sim.t_final,\n",
    "    ic_func=model_sim.ic_func,\n",
    "    pool=None,\n",
    ")\n",
    "\n",
    "# Truncated ratio of the configured simulation steps to the number of time points.\n",
    "steps_scale = int(model_sim.sim_steps / sim.T)\n",
    "sim.sample(steps_scale=steps_scale);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fresh-start",
   "metadata": {},
   "outputs": [],
   "source": [
    "# First two coordinates of the simulated samples, coloured by `sim.t_idx`.\n",
    "plt.scatter(\n",
    "    sim.x[:, 0],\n",
    "    sim.x[:, 1],\n",
    "    c=sim.t_idx,\n",
    "    alpha=0.25,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "wireless-buffer",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Re-import the `models` module so edits made to it since the last import take effect.\n",
    "importlib.reload(models)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "necessary-poland",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inputs for the model: a random (T, M, dim) particle initialisation, plus the simulated\n",
    "# samples (sim.x) and their labels (sim.t_idx) as tensors on `device`.\n",
    "x_init = torch.randn(model_sim.T, M, model_sim.dim) * 1.0\n",
    "x_obs = torch.tensor(sim.x, device=device)\n",
    "t_obs = torch.tensor(sim.t_idx, device=device)\n",
    "\n",
    "# sigma is left as None here; the optimize call receives sigma_final separately.\n",
    "model = models.TrajLoss(\n",
    "    x_init,\n",
    "    x_obs,\n",
    "    t_obs,\n",
    "    dt=model_sim.t_final / model_sim.T,\n",
    "    tau=model_sim.D,\n",
    "    sigma=None,\n",
    "    M=M,\n",
    "    lamda_reg=0.05,\n",
    "    lamda_cst=1.0,\n",
    "    sigma_cst=5.0,\n",
    "    branching_rate_fn=model_sim.branching_rate,\n",
    "    sinkhorn_iters=250,\n",
    "    device=device,\n",
    "    warm_start=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "proved-speech",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run without annealing: temp_init and temp_ratio are both 1.0.\n",
    "# The result is indexed later as output[1] (plotted as the objective curve) and\n",
    "# output[2] (the stored particle iterates).\n",
    "output = models.optimize(\n",
    "    model,\n",
    "    n_iter=2500,\n",
    "    eta_final=0.25,\n",
    "    tau_final=model_sim.D,\n",
    "    sigma_final=0.5,\n",
    "    temp_init=1.0,\n",
    "    temp_ratio=1.0,\n",
    "    N=M,\n",
    "    dim=model_sim.dim,\n",
    "    tloss=model,\n",
    "    print_interval=50,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "early-event",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Re-import the `models` module so edits made to it since the last import take effect.\n",
    "importlib.reload(models)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "disturbed-window",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keyword configuration for the annealed model; the values match the ones used for `model`.\n",
    "anneal_model_kwargs = dict(\n",
    "    dt=model_sim.t_final / model_sim.T,\n",
    "    tau=model_sim.D,\n",
    "    sigma=None,\n",
    "    M=M,\n",
    "    lamda_reg=0.05,\n",
    "    lamda_cst=1.0,\n",
    "    sigma_cst=5.0,\n",
    "    branching_rate_fn=model_sim.branching_rate,\n",
    "    sinkhorn_iters=250,\n",
    "    device=device,\n",
    "    warm_start=True,\n",
    ")\n",
    "\n",
    "# Fresh random (T, M, dim) initialisation for the annealed run.\n",
    "model_anneal = models.TrajLoss(\n",
    "    torch.randn(model_sim.T, M, model_sim.dim) * 1.0,\n",
    "    torch.tensor(sim.x, device=device),\n",
    "    torch.tensor(sim.t_idx, device=device),\n",
    "    **anneal_model_kwargs,\n",
    ")\n",
    "\n",
    "# Annealed run: temp_init = 5 and temp_ratio = (1/5)**(1/500), so 5 * temp_ratio**500 is 1\n",
    "# up to rounding.\n",
    "# NOTE(review): presumably the temperature is multiplied by temp_ratio once per iteration --\n",
    "# confirm in models.optimize.\n",
    "output_anneal = models.optimize(\n",
    "    model_anneal,\n",
    "    n_iter=2500,\n",
    "    eta_final=0.25,\n",
    "    tau_final=model_sim.D,\n",
    "    sigma_final=0.5,\n",
    "    temp_init=5,\n",
    "    temp_ratio=(1 / 5) ** (1 / 500),\n",
    "    N=M,\n",
    "    dim=model_sim.dim,\n",
    "    tloss=model_anneal,\n",
    "    print_interval=50,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "comparative-think",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inspect the shape of the first entry of output_anneal[2] (entries of output_anneal[2]\n",
    "# are used below as initial particle positions and for the snapshot plots).\n",
    "output_anneal[2][0].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "later-intent",
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqdm import tqdm\n",
    "\n",
    "# Objective value re-evaluated at every stored iterate of the annealed run, using a\n",
    "# fresh TrajLoss with sigma = 0.5 and 1_000 Sinkhorn iterations.\n",
    "primal_anneal = []\n",
    "\n",
    "for step in tqdm(range(len(output_anneal[2]))):\n",
    "    model_tmp = models.TrajLoss(\n",
    "        output_anneal[2][step],\n",
    "        torch.tensor(sim.x, device=device),\n",
    "        torch.tensor(sim.t_idx, device=device),\n",
    "        dt=model_sim.t_final / model_sim.T,\n",
    "        tau=model_sim.D,\n",
    "        sigma=0.5,\n",
    "        M=M,\n",
    "        lamda_reg=0.05,\n",
    "        lamda_cst=1.0,\n",
    "        sigma_cst=5.0,\n",
    "        branching_rate_fn=model_sim.branching_rate,\n",
    "        sinkhorn_iters=1_000,\n",
    "        device=device,\n",
    "        warm_start=True,\n",
    "    )\n",
    "    # NOTE(review): forward() runs before forward_primal(), presumably to populate state the\n",
    "    # primal evaluation relies on -- confirm against models.TrajLoss.\n",
    "    model_tmp.forward()\n",
    "    with torch.no_grad():\n",
    "        primal_value = model_tmp.forward_primal().item()\n",
    "        # One k-NN entropy estimate (k = 2) per slice along the first axis of model_tmp.x.\n",
    "        entropy_terms = [\n",
    "            models.entropy_est_knn(model_tmp.x[t, :, :], d=model_tmp.d, k=2)\n",
    "            for t in range(model_tmp.x.shape[0])\n",
    "        ]\n",
    "        primal_anneal.append(\n",
    "            primal_value + model_tmp.tau * model_tmp.lamda_reg * sum(entropy_terms)\n",
    "        )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "grave-upper",
   "metadata": {},
   "outputs": [],
   "source": [
    "def _mean_energy_to_final(iterates):\n",
    "    \"\"\"Mean energy distance between each stored iterate and the last one.\n",
    "\n",
    "    For every iterate except the final one, `dcor.energy_distance` is evaluated on the\n",
    "    paired slices of that iterate and of `iterates[-1]`, and the results are averaged.\n",
    "    Returns a list with `len(iterates) - 1` entries.\n",
    "    \"\"\"\n",
    "    distances = []\n",
    "    for k in range(len(iterates) - 1):\n",
    "        per_slice = [\n",
    "            dcor.energy_distance(x, y) for (x, y) in zip(iterates[k], iterates[-1])\n",
    "        ]\n",
    "        distances.append(np.mean(per_slice))\n",
    "    return distances\n",
    "\n",
    "\n",
    "err = _mean_energy_to_final(output[2])\n",
    "err_anneal = _mean_energy_to_final(output_anneal[2])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "latter-anxiety",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Two panels: objective curves on the left, distance to the final iterate on the right.\n",
    "fig, (ax_obj, ax_err) = plt.subplots(1, 2, figsize=(3 * PLT_CELL, PLT_CELL))\n",
    "\n",
    "# The annealed curve uses the re-evaluated objective `primal_anneal` rather than\n",
    "# output_anneal[1].\n",
    "ax_obj.plot(output[1], label=\"MFL\")\n",
    "ax_obj.plot(primal_anneal, label=\"MFL + Annealing\")\n",
    "ax_obj.set_ylim(2.25, 2.5)\n",
    "ax_obj.set_title(\"Reduced objective $F$\")\n",
    "ax_obj.set_xlabel(\"Iteration\")\n",
    "ax_obj.set_ylabel(\"$F$\")\n",
    "\n",
    "# The square root of the mean energy distances is what gets plotted, on a log scale.\n",
    "ax_err.plot(np.sqrt(np.array(err)), label=\"MFL\")\n",
    "ax_err.plot(np.sqrt(np.array(err_anneal)), label=\"MFL + Annealing\")\n",
    "ax_err.set_title(\"Energy distance to final iterate\")\n",
    "ax_err.set_xlabel(\"Iteration\")\n",
    "ax_err.set_ylabel(\"energy distance\")\n",
    "ax_err.set_yscale(\"log\")\n",
    "ax_err.legend()\n",
    "\n",
    "fig.tight_layout()\n",
    "fig.savefig(\"appendix_annealing_a.pdf\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "japanese-australian",
   "metadata": {},
   "outputs": [],
   "source": [
    "def _scatter_particles(ax, traj_model, title):\n",
    "    \"\"\"Scatter the particles stored in `traj_model.x` on `ax`, coloured by time index.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    ax : matplotlib Axes to draw on.\n",
    "    traj_model : object whose `x` attribute holds the particle positions; assumed to be a\n",
    "        torch tensor that reshapes to (T * M, dim) -- confirm against models.TrajLoss.\n",
    "    title : str used as the panel title.\n",
    "    \"\"\"\n",
    "    # Detach from autograd and copy to host memory before handing the data to matplotlib:\n",
    "    # a tensor living on a CUDA device cannot be converted to a numpy array directly, and\n",
    "    # `device` is set to cuda:0 whenever it is available.\n",
    "    pts = traj_model.x.detach().cpu().reshape(-1, model_sim.dim).numpy()\n",
    "    # Each of the T time indices is repeated M times to match the flattened particle order.\n",
    "    time_idx = np.kron(np.arange(model_sim.T), np.ones(M))\n",
    "    ax.scatter(pts[:, 0], pts[:, 1], c=time_idx, alpha=1, marker=\".\")\n",
    "    ax.set_title(title)\n",
    "    ax.set_xlabel(\"x\")\n",
    "    ax.set_ylabel(\"y\")\n",
    "    ax.set_xlim(-2.5, 2.5)\n",
    "    ax.set_ylim(-1.5, 0.5)\n",
    "\n",
    "\n",
    "# Final particle configurations of the plain and the annealed run, side by side.\n",
    "fig, (ax_mfl, ax_anneal) = plt.subplots(1, 2, figsize=(3 * PLT_CELL, 3 / 2 * PLT_CELL))\n",
    "_scatter_particles(ax_mfl, model, \"MFL\")\n",
    "_scatter_particles(ax_anneal, model_anneal, \"MFL + Annealing\")\n",
    "fig.tight_layout()\n",
    "fig.savefig(\"appendix_annealing_b.pdf\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "informative-circle",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Snapshots of the plain run's particles at selected iterations (1-based in the titles,\n",
    "# 0-based when indexing output[2]).\n",
    "shown_iters = np.array([1, 50, 250, 500, 2500]) - 1\n",
    "# Colour value per particle: each of the T time values repeated M times.\n",
    "time_colors = np.kron(np.linspace(0, model_sim.t_final, model_sim.T), np.ones(M))\n",
    "\n",
    "fig, axes = plt.subplots(1, 5, figsize=(4 * PLT_CELL, 1.125 * PLT_CELL))\n",
    "for i, (ax, j) in enumerate(zip(axes, shown_iters)):\n",
    "    pts = output[2][j, :, :, :].reshape(-1, model_sim.dim)\n",
    "    # `im` is kept so a shared colorbar could be attached later via fig.colorbar(im, ...).\n",
    "    im = ax.scatter(pts[:, 0], pts[:, 1], c=time_colors, alpha=0.25, marker=\".\")\n",
    "    ax.set_ylim(-1.75, 0.5)\n",
    "    ax.set_xlim(-2, 2)\n",
    "    ax.set_title(\"Iter %d\" % (j + 1))\n",
    "    ax.set_xlabel(\"x\")\n",
    "    ax.set_ylabel(\"y\")\n",
    "    # Only the left-most panel keeps its y axis.\n",
    "    if i > 0:\n",
    "        ax.get_yaxis().set_visible(False)\n",
    "\n",
    "fig.suptitle(\"MFL\")\n",
    "fig.tight_layout()\n",
    "fig.savefig(\"appendix_annealing_b_new.pdf\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "committed-province",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Snapshots of the annealed run's particles at selected iterations (1-based in the titles,\n",
    "# 0-based when indexing output_anneal[2]).\n",
    "shown_iters = np.array([1, 50, 250, 500, 2500]) - 1\n",
    "# Colour value per particle: each of the T time values repeated M times.\n",
    "time_colors = np.kron(np.linspace(0, model_sim.t_final, model_sim.T), np.ones(M))\n",
    "\n",
    "fig, axes = plt.subplots(1, 5, figsize=(4 * PLT_CELL, 1.125 * PLT_CELL))\n",
    "for i, (ax, j) in enumerate(zip(axes, shown_iters)):\n",
    "    pts = output_anneal[2][j, :, :, :].reshape(-1, model_sim.dim)\n",
    "    # `im` is kept so a shared colorbar could be attached later via fig.colorbar(im, ...).\n",
    "    im = ax.scatter(pts[:, 0], pts[:, 1], c=time_colors, alpha=0.25, marker=\".\")\n",
    "    ax.set_ylim(-1.75, 0.5)\n",
    "    ax.set_xlim(-2, 2)\n",
    "    ax.set_title(\"Iter %d\" % (j + 1))\n",
    "    ax.set_xlabel(\"x\")\n",
    "    ax.set_ylabel(\"y\")\n",
    "    # Only the left-most panel keeps its y axis.\n",
    "    if i > 0:\n",
    "        ax.get_yaxis().set_visible(False)\n",
    "\n",
    "fig.suptitle(\"MFL + Annealing\")\n",
    "fig.tight_layout()\n",
    "fig.savefig(\"appendix_annealing_c_new.pdf\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "timely-construction",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "grateful-moment",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bottom-village",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
