{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "discrete-explosion",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# Give every numerical backend the same thread-pool size. The value is kept as a\n",
    "# string because environment variables can only hold text.\n",
    "num_threads = \"8\"\n",
    "os.environ.update({\n",
    "    \"OMP_NUM_THREADS\": num_threads,\n",
    "    \"OPENBLAS_NUM_THREADS\": num_threads,\n",
    "    \"MKL_NUM_THREADS\": num_threads,\n",
    "    \"VECLIB_MAXIMUM_THREADS\": num_threads,\n",
    "    \"NUMEXPR_NUM_THREADS\": num_threads,\n",
    "})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "heated-compact",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import torch\n",
    "from torch.autograd import grad, Variable\n",
    "import autograd\n",
    "import autograd.numpy as np\n",
    "import copy\n",
    "import scipy as sp\n",
    "from scipy import stats\n",
    "from sklearn import metrics\n",
    "import sys\n",
    "import ot\n",
    "import gwot\n",
    "from gwot import models, sim, ts, util\n",
    "import gwot.bridgesampling as bs\n",
    "import dcor\n",
    "\n",
    "sys.path.append(\"fig1_batch\")\n",
    "import importlib\n",
    "import models\n",
    "import random\n",
    "import model_fig1 as model_sim"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "sufficient-raise",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Re-import model_fig1 (bound to model_sim) so that edits to the file on disk\n",
    "# are picked up without restarting the kernel.\n",
    "importlib.reload(model_sim)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "tribal-proposition",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Use the first CUDA device when one is available, otherwise fall back to the CPU.\n",
    "if torch.cuda.is_available():\n",
    "    device = torch.device(\"cuda:0\")\n",
    "else:\n",
    "    device = torch.device(\"cpu\")\n",
    "\n",
    "# New floating-point tensors are created in double precision from here on.\n",
    "torch.set_default_dtype(torch.float64)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "least-search",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Base size of one subplot panel; the figsize tuples in the plotting cells are\n",
    "# written as multiples of this value.\n",
    "PLT_CELL = 2.5"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "operating-ownership",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Seed every random number generator this notebook imports, so that each run\n",
    "# starts from the same RNG state.\n",
    "SRAND = 0\n",
    "random.seed(SRAND)        # Python stdlib RNG (imported above but previously never seeded)\n",
    "torch.manual_seed(SRAND)  # torch RNG, used e.g. by torch.randn for the initial particles\n",
    "np.random.seed(SRAND)     # NumPy global RNG (np is autograd.numpy in this notebook)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "welsh-assets",
   "metadata": {},
   "outputs": [],
   "source": [
    "M = 100\n",
    "N_vals = [64, 1]\n",
    "N0 = 64\n",
    "\n",
    "\n",
    "def _build_simulation(n_mid):\n",
    "    # Build one gwot Simulation whose N array is N0 at the first and last of the\n",
    "    # model_sim.T entries and n_mid at every entry in between.\n",
    "    n_per_time = np.array([N0] + [n_mid] * (model_sim.T - 2) + [N0])\n",
    "    return gwot.sim.Simulation(\n",
    "        V = model_sim.Psi,\n",
    "        dV = model_sim.dPsi,\n",
    "        birth_death = False,\n",
    "        N = n_per_time,\n",
    "        T = model_sim.T,\n",
    "        d = model_sim.dim,\n",
    "        D = model_sim.D,\n",
    "        t_final = model_sim.t_final,\n",
    "        ic_func = model_sim.ic_func,\n",
    "        pool = None,\n",
    "    )\n",
    "\n",
    "\n",
    "# one simulation object per entry of N_vals\n",
    "sims_all = [_build_simulation(n) for n in N_vals]\n",
    "\n",
    "# draw the samples of every simulation\n",
    "for s in sims_all:\n",
    "    s.sample(steps_scale = int(model_sim.sim_steps/s.T));"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "average-duplicate",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reference simulation: deep-copy the first simulation, set N to 500 for each of\n",
    "# the model_sim.T entries, and sample it again with the same steps_scale.\n",
    "sim_gt = copy.deepcopy(sims_all[0])\n",
    "sim_gt.N = np.array([500, ]*model_sim.T)\n",
    "sim_gt.sample(steps_scale = int(model_sim.sim_steps/sims_all[0].T));"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "weighted-interval",
   "metadata": {},
   "outputs": [],
   "source": [
    "# plot samples\n",
    "i = 0\n",
    "plt.scatter(sims_all[i].x[:, 0], sims_all[i].x[:, 1], alpha = 0.25, c = sims_all[i].t_idx)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "thermal-bleeding",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Same scatter for the reference simulation sim_gt, coloured by its t_idx.\n",
    "plt.scatter(sim_gt.x[:, 0], sim_gt.x[:, 1], alpha = 0.25, c = sim_gt.t_idx)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "color-occurrence",
   "metadata": {},
   "outputs": [],
   "source": [
    "k = 0\n",
    "# time value attached to each of the model_sim.T time indices\n",
    "t_grid = np.linspace(0, model_sim.t_final, model_sim.T)\n",
    "\n",
    "# left panel: coordinate k of the last simulation in sims_all against time\n",
    "plt.subplot(1, 2, 1)\n",
    "plt.scatter(t_grid[sims_all[-1].t_idx], sims_all[-1].x[:, k], alpha = 0.5, color = \"red\")\n",
    "plt.ylim(-2.0, 2.0)\n",
    "\n",
    "# right panel: the same coordinate of the reference simulation sim_gt\n",
    "plt.subplot(1, 2, 2)\n",
    "plt.scatter(t_grid[sim_gt.t_idx], sim_gt.x[:, k], alpha = 0.01, color = \"blue\")\n",
    "plt.ylim(-2.0, 2.0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "oriented-juice",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pick up on-disk edits to the local modules without restarting the kernel.\n",
    "importlib.reload(models)\n",
    "importlib.reload(model_sim)\n",
    "# Reuse the thread count configured in the first cell instead of repeating the\n",
    "# literal, so that the BLAS/OpenMP pools and torch cannot drift apart.\n",
    "torch.set_num_threads(int(num_threads))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "industrial-haven",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One TrajLoss per simulation in sims_all. The first argument is Gaussian noise of\n",
    "# shape (model_sim.T, M, model_sim.dim) scaled by 0.1; the simulation's x and t_idx\n",
    "# are passed as tensors created on `device`.\n",
    "models_all = [models.TrajLoss(torch.randn(model_sim.T, M, model_sim.dim)*0.1,\n",
    "                        torch.tensor(s.x, device = device), \n",
    "                        torch.tensor(s.t_idx, device = device), \n",
    "                        dt = model_sim.t_final/model_sim.T, tau = model_sim.D, sigma = None, M = M,\n",
    "                        lamda_reg = 0.05, lamda_cst = 0, sigma_cst = float(\"Inf\"),\n",
    "                        branching_rate_fn = model_sim.branching_rate,\n",
    "                        sinkhorn_iters = 250, device = device, warm_start = True) for s in sims_all]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "virgin-economics",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run models.optimize once per TrajLoss model; every model gets the same settings.\n",
    "outputs_all = [\n",
    "    models.optimize(\n",
    "        m,\n",
    "        n_iter = 2_500,\n",
    "        eta_final = 0.1,\n",
    "        tau_final = model_sim.D,\n",
    "        sigma_final = 0.5,\n",
    "        N = M,\n",
    "        temp_init = 1.0,\n",
    "        temp_ratio = 1.0,\n",
    "        dim = model_sim.dim,\n",
    "        tloss = m,\n",
    "        print_interval = 50,\n",
    "    )\n",
    "    for m in models_all\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "acceptable-leave",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Element 1 of the first optimize() output, plotted against its index.\n",
    "plt.plot(outputs_all[0][1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "present-adventure",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Element 1 of every optimize() output, stacked so that each model is one curve.\n",
    "plt.plot(np.vstack([o[1] for o in outputs_all]).T)\n",
    "plt.title(\"Objective: primal\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "herbal-dispute",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Unit vector with equal weight on the first two coordinates and zero elsewhere.\n",
    "u = np.full(model_sim.dim, 1)\n",
    "u[2:] = 0\n",
    "u = u/np.linalg.norm(u)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "israeli-vessel",
   "metadata": {},
   "outputs": [],
   "source": [
    "i = -1\n",
    "# Move the particles to host memory before plotting: matplotlib converts its inputs\n",
    "# through NumPy, which fails for a tensor that lives on a CUDA device.\n",
    "x_np = models_all[i].x.detach().cpu().numpy().reshape(-1, model_sim.dim)\n",
    "# time value of every particle: each of the model_sim.T times repeated M times\n",
    "t_of_particle = np.kron(np.linspace(0, model_sim.t_final, model_sim.T), np.ones(M))\n",
    "plt.scatter(t_of_particle, x_np[:, 0], c = np.kron(np.arange(model_sim.T), np.ones(M)), alpha = 0.1);\n",
    "# plt.scatter(np.linspace(0, model_sim.t_final, model_sim.T)[sims_all[i].t_idx], np.dot(sims_all[i].x, u), alpha = 0.5, color = \"red\", marker = \"x\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "inappropriate-providence",
   "metadata": {},
   "outputs": [],
   "source": [
    "i = 0\n",
    "# observed samples of simulation i\n",
    "plt.scatter(sims_all[i].x[:, 0], sims_all[i].x[:, 1], color = \"red\", marker = \"x\")\n",
    "# Move the particles to host memory before plotting: matplotlib converts its inputs\n",
    "# through NumPy, which fails for a tensor that lives on a CUDA device.\n",
    "x_np = models_all[i].x.detach().cpu().numpy().reshape(-1, model_sim.dim)\n",
    "plt.scatter(x_np[:, 0], x_np[:, 1], c = np.kron(np.arange(model_sim.T), np.ones(M)), alpha = 0.25)\n",
    "plt.axis('auto')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "threaded-accommodation",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One gwot OTModel per simulation in sims_all, all with the same settings, each\n",
    "# then solved with solve_lbfgs using identical options.\n",
    "models_all_gwot = [gwot.models.OTModel(s, lamda_reg = 0.0025,\n",
    "          eps_df = 0.01*torch.ones(s.T), \n",
    "          growth_constraint=\"exact\", \n",
    "          pi_0 = \"uniform\",\n",
    "          use_keops = False,\n",
    "          device = device) for s in sims_all]\n",
    "\n",
    "for m in models_all_gwot:\n",
    "    m.solve_lbfgs(steps = 10, \n",
    "                    max_iter = 50, \n",
    "                    lr = 1,\n",
    "                    history_size = 50, \n",
    "                    line_search_fn = 'strong_wolfe', \n",
    "                    factor = 2, \n",
    "                    tol = 1e-5, \n",
    "                    retry_max = 0);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "referenced-lover",
   "metadata": {},
   "outputs": [],
   "source": [
    "def _sample_gwot_marginal(gwot_model, t):\n",
    "    # Draw M rows of gwot_model.ts.x, using the L1-normalised get_R(t) as the\n",
    "    # sampling probabilities.\n",
    "    weights = torch.nn.functional.normalize(gwot_model.get_R(t).detach(), p = 1.0, dim = 0)\n",
    "    # np.random.choice needs host-side probabilities; passing the tensor directly\n",
    "    # fails when it lives on a CUDA device.\n",
    "    probs = weights.cpu().numpy()\n",
    "    idx = np.random.choice(gwot_model.ts.x.shape[0], size = M, p = probs)\n",
    "    return gwot_model.ts.x[idx, :]\n",
    "\n",
    "\n",
    "# For every gWOT model: one (M, dim) sample per time index, stacked along axis 0.\n",
    "X_gwot_samples_all = [np.stack([_sample_gwot_marginal(m, t) for t in range(model_sim.T)]) for m in models_all_gwot]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "stunning-rubber",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Trajectory counts: N_paths_gt trajectories are drawn from sim_gt here; N_paths is\n",
    "# the count passed to bs.sample_paths in the plotting cells.\n",
    "N_paths_gt = 250\n",
    "N_paths = 50\n",
    "paths_gt = sim_gt.sample_trajectory(steps_scale = int(model_sim.sim_steps/sim_gt.T), N = N_paths_gt)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "editorial-aurora",
   "metadata": {},
   "outputs": [],
   "source": [
    "from matplotlib.gridspec import GridSpec\n",
    "fig = plt.figure(figsize = (3*PLT_CELL, 2*PLT_CELL))\n",
    "gs = GridSpec(2, 3, figure=fig)\n",
    "\n",
    "# Column 0: simulated samples drawn over the ground-truth paths. The row index is\n",
    "# reversed, so the first entry of N_vals lands in the bottom row.\n",
    "for (i, s) in enumerate(sims_all):\n",
    "    ax = fig.add_subplot(gs[len(models_all)-i-1, 0])\n",
    "    ax.plot(paths_gt[:, :, 0].T, paths_gt[:, :, 1].T, color = 'grey', alpha = 0.05);\n",
    "    ax.scatter(s.x[:, 0], s.x[:, 1], alpha = 0.5, c = s.t_idx/s.T, marker = \".\")\n",
    "    ax.set_ylim(-1.75, 0.5); ax.set_xlim(-2, 2)\n",
    "    ax.set_xlabel(\"$x_0$\"); ax.set_ylabel(\"$x_1$\")\n",
    "    ax.text(0.75, 0.25, \"N = %d\" % N_vals[i])\n",
    "    if i > 0:\n",
    "        # rows above the bottom one carry the column title and hide their x axis\n",
    "        ax.set_title(\"Samples\")\n",
    "        ax.get_xaxis().set_visible(False)\n",
    "\n",
    "# Column 1: particles of each TrajLoss model plus paths sampled through its couplings.\n",
    "for (i, m) in enumerate(models_all):\n",
    "    ax = fig.add_subplot(gs[len(models_all)-i-1, 1])\n",
    "    # Host-side copy of the particles: matplotlib converts its inputs through NumPy,\n",
    "    # which fails for a tensor that lives on a CUDA device.\n",
    "    x_np = m.x.detach().cpu().numpy()\n",
    "    with torch.no_grad():\n",
    "        # the lambda argument is named k so that it does not shadow the loop index i\n",
    "        paths = bs.sample_paths(None, N = N_paths, coord = True, x_all = x_np, \n",
    "                            get_gamma_fn = lambda k : m.loss_reg.ot_losses[k].coupling().cpu(), num_couplings = model_sim.T-1)\n",
    "    ax.plot(paths[:, :, 0].T, paths[:, :, 1].T, color = 'grey', alpha = 0.25);\n",
    "    x_flat = x_np.reshape(-1, model_sim.dim)\n",
    "    # `im` is kept because the shared colorbar below is built from this scatter\n",
    "    im = ax.scatter(x_flat[:, 0], x_flat[:, 1], c = np.kron(np.linspace(0, model_sim.t_final, model_sim.T), np.ones(M)), alpha = 0.25, marker = \".\")\n",
    "    ax.set_xlabel(\"$x_0$\"); ax.set_ylabel(\"$x_1$\")\n",
    "    ax.set_ylim(-1.75, 0.5); ax.set_xlim(-2, 2)\n",
    "    ax.text(0.75, 0.25, \"N = %d\" % N_vals[i])\n",
    "    ax.get_yaxis().set_visible(False)\n",
    "    if i > 0:\n",
    "        ax.set_title(\"MFL\")\n",
    "        ax.get_xaxis().set_visible(False)\n",
    "\n",
    "# Column 2: points sampled from each gWOT model plus paths sampled through its couplings.\n",
    "for (i, x) in enumerate(X_gwot_samples_all):\n",
    "    ax = fig.add_subplot(gs[len(models_all)-i-1, 2])\n",
    "    with torch.no_grad():\n",
    "        paths = bs.sample_paths(None, N = N_paths, coord = True, x_all = [sims_all[i].x, ]*sims_all[i].T, \n",
    "                                get_gamma_fn = lambda j : models_all_gwot[i].get_coupling_reg(j, K = models_all_gwot[i].get_K(j)), num_couplings = sims_all[i].T-1)\n",
    "        ax.plot(paths[:, :, 0].T, paths[:, :, 1].T, color = 'grey', alpha = 0.25);\n",
    "    ax.scatter(x.reshape(-1, model_sim.dim)[:, 0], x.reshape(-1, model_sim.dim)[:, 1], c = np.kron(np.arange(model_sim.T), np.ones(M)), alpha = 0.25, marker = \".\")\n",
    "    ax.set_ylim(-1.75, 0.5); ax.set_xlim(-2, 2)\n",
    "    ax.set_xlabel(\"$x_0$\"); ax.set_ylabel(\"$x_1$\")\n",
    "    ax.text(0.75, 0.25, \"N = %d\" % N_vals[i])\n",
    "    ax.get_yaxis().set_visible(False)\n",
    "    if i > 0:\n",
    "        ax.set_title(\"gWOT\")\n",
    "        ax.get_xaxis().set_visible(False)\n",
    "\n",
    "plt.tight_layout() \n",
    "\n",
    "# Reserve a strip on the right of the figure for a colorbar shared by all panels.\n",
    "fig.subplots_adjust(right=0.9)\n",
    "cbar_ax = fig.add_axes([0.925, 0.15, 0.025, 0.7])\n",
    "cb = fig.colorbar(im, cax=cbar_ax)\n",
    "# Draw the colorbar fully opaque even though the scatter itself uses alpha = 0.25.\n",
    "cb.set_alpha(1)\n",
    "if hasattr(cb, \"draw_all\"):\n",
    "    cb.draw_all()\n",
    "else:\n",
    "    # Colorbar.draw_all no longer exists in newer matplotlib releases; set the\n",
    "    # alpha directly on the colorbar's drawn artist instead.\n",
    "    cb.solids.set(alpha=1)\n",
    "cbar_ax.set_title(\"$t$\")\n",
    "\n",
    "plt.savefig(\"fig1_langevin_new.pdf\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "desperate-desktop",
   "metadata": {},
   "outputs": [],
   "source": [
    "N_paths = 50\n",
    "plt.figure(figsize = (3*PLT_CELL, 1.25*PLT_CELL))\n",
    "# One panel per TrajLoss model: its particles plus paths sampled through its couplings.\n",
    "for (i, m) in enumerate(models_all):\n",
    "    plt.subplot(1, len(N_vals), i+1)\n",
    "    # Host-side copy of the particles: matplotlib converts its inputs through NumPy,\n",
    "    # which fails for a tensor that lives on a CUDA device.\n",
    "    x_np = m.x.detach().cpu().numpy()\n",
    "    with torch.no_grad():\n",
    "        # the lambda argument is named k so that it does not shadow the loop index i\n",
    "        paths = bs.sample_paths(None, N = N_paths, coord = True, x_all = x_np, \n",
    "                            get_gamma_fn = lambda k : m.loss_reg.ot_losses[k].coupling().cpu(), num_couplings = model_sim.T-1)\n",
    "    plt.plot(paths[:, :, 0].T, paths[:, :, 1].T, color = 'grey', alpha = 0.25);\n",
    "    x_flat = x_np.reshape(-1, model_sim.dim)\n",
    "    plt.scatter(x_flat[:, 0], x_flat[:, 1], c = np.kron(np.arange(model_sim.T), np.ones(M)), alpha = 0.25, marker = \".\")\n",
    "    plt.xlabel(\"x\"); plt.ylabel(\"y\")\n",
    "    plt.ylim(-1.75, 0.5); plt.xlim(-2, 2)\n",
    "    plt.title(\"N = %d\" % N_vals[i])\n",
    "    if i > 0:\n",
    "        # only the left-most panel keeps its y axis\n",
    "        plt.gca().get_yaxis().set_visible(False)\n",
    "plt.suptitle(\"Langevin\")\n",
    "plt.tight_layout()\n",
    "plt.savefig(\"fig1_langevin.pdf\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "attractive-warner",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize = (3*PLT_CELL, 1.25*PLT_CELL))\n",
    "# One panel per gWOT model: its sampled points plus paths sampled through its couplings.\n",
    "for (i, x) in enumerate(X_gwot_samples_all):\n",
    "    ax = plt.subplot(1, len(N_vals), i+1)\n",
    "    sim_i, gwot_i = sims_all[i], models_all_gwot[i]\n",
    "    with torch.no_grad():\n",
    "        paths = bs.sample_paths(None, N = N_paths, coord = True, x_all = [sim_i.x, ]*sim_i.T, \n",
    "                                get_gamma_fn = lambda j : gwot_i.get_coupling_reg(j, K = gwot_i.get_K(j)), num_couplings = sim_i.T-1)\n",
    "        ax.plot(paths[:, :, 0].T, paths[:, :, 1].T, color = 'grey', alpha = 0.25);\n",
    "    pts = x.reshape(-1, model_sim.dim)\n",
    "    ax.scatter(pts[:, 0], pts[:, 1], c = np.kron(np.arange(model_sim.T), np.ones(M)), alpha = 0.25, marker = \".\")\n",
    "    ax.set_ylim(-1.75, 0.5); ax.set_xlim(-2, 2)\n",
    "    ax.set_xlabel(\"x\"); ax.set_ylabel(\"y\")\n",
    "    if i > 0:\n",
    "        # only the left-most panel keeps its y axis\n",
    "        ax.get_yaxis().set_visible(False)\n",
    "    ax.set_title(\"N = %d\" % N_vals[i])\n",
    "plt.suptitle(\"gWOT\")\n",
    "plt.tight_layout()\n",
    "plt.savefig(\"fig1_gwot.pdf\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "crazy-aurora",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize = (3*PLT_CELL, 1.25*PLT_CELL))\n",
    "# One panel per simulation: its samples in red over the ground-truth paths in grey.\n",
    "for (i, s) in enumerate(sims_all):\n",
    "    ax = plt.subplot(1, len(N_vals), i+1)\n",
    "    ax.plot(paths_gt[:, :, 0].T, paths_gt[:, :, 1].T, color = 'grey', alpha = 0.05);\n",
    "    ax.scatter(s.x[:, 0], s.x[:, 1], alpha = 0.5, color = \"red\", marker = \".\")\n",
    "    ax.set_ylim(-1.75, 0.5); ax.set_xlim(-2, 2)\n",
    "    if i > 0:\n",
    "        # only the left-most panel keeps its y axis\n",
    "        ax.get_yaxis().set_visible(False)\n",
    "    ax.set_xlabel(\"x\"); ax.set_ylabel(\"y\")\n",
    "    ax.set_title(\"N = %d\" % N_vals[i])\n",
    "plt.suptitle(\"Samples\")\n",
    "plt.tight_layout()\n",
    "plt.savefig(\"fig1_samples.pdf\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "hollow-fault",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize = (3*PLT_CELL, PLT_CELL))\n",
    "i = 1\n",
    "# Select the model explicitly. Previously this cell used whatever `m` was left over\n",
    "# from the last loop that happened to run in an earlier cell.\n",
    "m = models_all[i]\n",
    "\n",
    "# Panel 1: reference simulation over the ground-truth paths.\n",
    "plt.subplot(1, 4, 1)\n",
    "plt.plot(paths_gt[:, :, 0].T, paths_gt[:, :, 1].T, color = 'grey', alpha = 0.05);\n",
    "plt.scatter(sim_gt.x[:, 0], sim_gt.x[:, 1], alpha = 0.025, color = \"blue\", marker = \".\")\n",
    "plt.ylim(-1.75, 0.5); plt.xlim(-2, 2)\n",
    "plt.xlabel(\"x\"); plt.ylabel(\"y\")\n",
    "plt.title(\"Ground truth\")\n",
    "\n",
    "# Panel 2: samples of simulation i over the ground-truth paths.\n",
    "plt.subplot(1, 4, 2)\n",
    "plt.plot(paths_gt[:, :, 0].T, paths_gt[:, :, 1].T, color = 'grey', alpha = 0.05);\n",
    "plt.scatter(sims_all[i].x[:, 0], sims_all[i].x[:, 1], alpha = 0.5, color = \"red\", marker = \".\")\n",
    "plt.ylim(-1.75, 0.5); plt.xlim(-2, 2)\n",
    "plt.xlabel(\"x\")\n",
    "plt.title(\"Samples\")\n",
    "plt.gca().get_yaxis().set_visible(False)\n",
    "\n",
    "# Panel 3: particles of model i plus paths sampled through its couplings.\n",
    "plt.subplot(1, 4, 3)\n",
    "# Host-side copy of the particles: matplotlib converts its inputs through NumPy,\n",
    "# which fails for a tensor that lives on a CUDA device.\n",
    "x_np = m.x.detach().cpu().numpy()\n",
    "with torch.no_grad():\n",
    "    # the lambda argument is named k so that it does not shadow the index i\n",
    "    paths = bs.sample_paths(None, N = N_paths, coord = True, x_all = x_np, \n",
    "                        get_gamma_fn = lambda k : m.loss_reg.ot_losses[k].coupling().cpu(), num_couplings = model_sim.T-1)\n",
    "plt.plot(paths[:, :, 0].T, paths[:, :, 1].T, color = 'grey', alpha = 0.25);\n",
    "x_flat = x_np.reshape(-1, model_sim.dim)\n",
    "plt.scatter(x_flat[:, 0], x_flat[:, 1], c = np.kron(np.arange(model_sim.T), np.ones(M)), alpha = 0.25, marker = \".\")\n",
    "plt.xlabel(\"x\")\n",
    "plt.ylim(-1.75, 0.5); plt.xlim(-2, 2)\n",
    "plt.title(\"Reconstruction\")\n",
    "plt.gca().get_yaxis().set_visible(False)\n",
    "\n",
    "# The fourth slot of the 1x4 grid is left unused.\n",
    "plt.tight_layout()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "checked-digest",
   "metadata": {},
   "outputs": [],
   "source": [
    "fig = plt.figure(figsize = (3.5*PLT_CELL, 1*PLT_CELL))\n",
    "# colour of every point: each of the model_sim.T time values repeated M times\n",
    "t_colors = np.kron(np.linspace(0, model_sim.t_final, model_sim.T), np.ones(M))\n",
    "# One panel per selected iteration; j is the zero-based index into outputs_all[0][2]\n",
    "# (presumably the per-iteration particle history returned by models.optimize - TODO confirm).\n",
    "for (i, j) in enumerate(np.array([1, 50, 500, 2500])-1):\n",
    "    plt.subplot(1, 4, i+1)\n",
    "    snapshot = outputs_all[0][2][j, :, :, :].reshape(-1, model_sim.dim)\n",
    "    # `im` is kept because the shared colorbar below is built from this scatter\n",
    "    im = plt.scatter(snapshot[:, 0], snapshot[:, 1], c = t_colors, alpha = 0.25, marker = \".\")\n",
    "    plt.ylim(-1.75, 0.5); plt.xlim(-2, 2)\n",
    "    plt.title(\"Iter %d\" % (j+1))\n",
    "    plt.xlabel(\"x\")\n",
    "    plt.ylabel(\"y\")\n",
    "    if i % 4 > 0:\n",
    "        # only the left-most panel keeps its y axis\n",
    "        plt.gca().get_yaxis().set_visible(False)\n",
    "plt.tight_layout()\n",
    "\n",
    "# Reserve a strip on the right of the figure for a colorbar shared by all panels.\n",
    "fig.subplots_adjust(right=0.9)\n",
    "cbar_ax = fig.add_axes([0.925, 0.15, 0.025, 0.7])\n",
    "cb = fig.colorbar(im, cax=cbar_ax)\n",
    "# Draw the colorbar fully opaque even though the scatter itself uses alpha = 0.25.\n",
    "cb.set_alpha(1)\n",
    "if hasattr(cb, \"draw_all\"):\n",
    "    cb.draw_all()\n",
    "else:\n",
    "    # Colorbar.draw_all no longer exists in newer matplotlib releases; set the\n",
    "    # alpha directly on the colorbar's drawn artist instead.\n",
    "    cb.solids.set(alpha=1)\n",
    "cbar_ax.set_title(\"$t$\")\n",
    "\n",
    "plt.savefig(\"fig2_trainingdynamics.pdf\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "matched-distinction",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Element 1 of the first optimize() output against its index, drawn in black and\n",
    "# saved as a single-panel figure.\n",
    "plt.figure(figsize = (PLT_CELL, PLT_CELL))\n",
    "plt.plot(outputs_all[0][1], 'k')\n",
    "plt.title(\"Reduced objective $F$\")\n",
    "plt.xlabel(\"Iteration\")\n",
    "plt.ylabel(\"$F$\")\n",
    "plt.tight_layout()\n",
    "# The y range is fixed by hand and applied after tight_layout.\n",
    "plt.ylim(2.25, 3.0)\n",
    "plt.savefig(\"fig2_objective.pdf\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
