{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "appropriate-fountain",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# Thread-count limit written to every backend variable listed below. This cell\n",
    "# sits ahead of the numerical imports in the next cell (presumably so that the\n",
    "# limits are already in place when those libraries load).\n",
    "num_threads = \"16\"\n",
    "for _thread_var in (\n",
    "    \"OMP_NUM_THREADS\",\n",
    "    \"OPENBLAS_NUM_THREADS\",\n",
    "    \"MKL_NUM_THREADS\",\n",
    "    \"VECLIB_MAXIMUM_THREADS\",\n",
    "    \"NUMEXPR_NUM_THREADS\",\n",
    "):\n",
    "    os.environ[_thread_var] = num_threads"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "distinguished-death",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import torch\n",
    "import copy\n",
    "import scipy as sp\n",
    "from scipy import stats\n",
    "from sklearn import metrics\n",
    "import sys\n",
    "import ot\n",
    "import gwot\n",
    "from gwot import models, sim, ts, util\n",
    "import gwot.bridgesampling as bs\n",
    "import dcor\n",
    "from tqdm import tqdm\n",
    "import numpy as np\n",
    "\n",
    "sys.path.append(\"..\")\n",
    "import importlib\n",
    "import models\n",
    "import random\n",
    "import mmd"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ordinary-screw",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Base figure dimension passed to figsize: the figures below are created as\n",
    "# (PLT_CELL, PLT_CELL) for one panel or (3*PLT_CELL, PLT_CELL) for three panels.\n",
    "PLT_CELL = 2.5"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "massive-python",
   "metadata": {},
   "outputs": [],
   "source": [
    "import glob\n",
    "\n",
    "# Langevin result files. The run parameters are parsed from the \"_\"-separated\n",
    "# file name: token 2 is N, token 4 the seed and token 6 is \"<lamda>.npy\".\n",
    "fnames_all = glob.glob(\"out_N_*.npy\")\n",
    "if not fnames_all:\n",
    "    # day_gt below is read from the first file, so an empty match cannot proceed.\n",
    "    raise FileNotFoundError(\"no files matching out_N_*.npy in the working directory\")\n",
    "_name_tokens = [f.split(\"_\") for f in fnames_all]\n",
    "srand_all = np.array([int(t[4]) for t in _name_tokens])\n",
    "lamda_all = np.array([float(t[6].split(\".npy\")[0]) for t in _name_tokens])\n",
    "N_all = np.array([int(t[2]) for t in _name_tokens])\n",
    "\n",
    "# Unpickle every file once and reuse the stored dict for all of its fields\n",
    "# (previously each file was loaded separately for every field).\n",
    "# NOTE: allow_pickle = True unpickles arbitrary objects -- only load trusted files.\n",
    "_results_all = [np.load(f, allow_pickle = True).item(0) for f in fnames_all]\n",
    "x_all = [r[\"model_x\"] for r in _results_all]\n",
    "x_gt_all = [r[\"X_gt\"] for r in _results_all]\n",
    "# The day labels are read from the first file only and reused for every run.\n",
    "day_gt = _results_all[0][\"day_gt\"]\n",
    "tsdata_all = [r[\"tsdata\"] for r in _results_all]\n",
    "# Drop the references to the full dicts; only the extracted fields are kept.\n",
    "del _results_all"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "biological-montgomery",
   "metadata": {},
   "outputs": [],
   "source": [
    "# gWOT result files. The extra \"gwot\" token in the file name shifts every\n",
    "# parameter one position to the right: token 3 is N, token 5 the seed and\n",
    "# token 7 is \"<lamda>.npy\".\n",
    "fnames_all_gwot = glob.glob(\"out_gwot_N_*.npy\")\n",
    "_name_tokens_gwot = [f.split(\"_\") for f in fnames_all_gwot]\n",
    "srand_all_gwot = np.array([int(t[5]) for t in _name_tokens_gwot])\n",
    "lamda_all_gwot = np.array([float(t[7].split(\".npy\")[0]) for t in _name_tokens_gwot])\n",
    "N_all_gwot = np.array([int(t[3]) for t in _name_tokens_gwot])\n",
    "\n",
    "# Unpickle every file once and reuse the stored dict for all of its fields\n",
    "# (previously each file was loaded separately for every field).\n",
    "# NOTE: allow_pickle = True unpickles arbitrary objects -- only load trusted files.\n",
    "_results_all_gwot = [np.load(f, allow_pickle = True).item(0) for f in fnames_all_gwot]\n",
    "x_all_gwot = [r[\"samples_gwot\"] for r in _results_all_gwot]\n",
    "x_gt_all_gwot = [r[\"X_gt\"] for r in _results_all_gwot]\n",
    "tsdata_all_gwot = [r[\"tsdata\"] for r in _results_all_gwot]\n",
    "# Drop the references to the full dicts; only the extracted fields are kept.\n",
    "del _results_all_gwot"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "binary-guard",
   "metadata": {},
   "outputs": [],
   "source": [
    "# days: sorted distinct values of day_gt; day_idx: for every entry of day_gt,\n",
    "# the position of its value in days (so days[day_idx] reproduces day_gt).\n",
    "days, day_idx = np.unique(day_gt, return_inverse = True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "apart-currency",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Square root of the energy distance between the ground-truth points of each day\n",
    "# and the model output for that day: one row per result file, one column per day.\n",
    "with torch.no_grad():\n",
    "    _rows = []\n",
    "    for _run in tqdm(range(len(x_all)), position = 0, leave = True):\n",
    "        _rows.append([\n",
    "            dcor.energy_distance(x_gt_all[_run][day_idx == _day, :], x_all[_run][_day, :])\n",
    "            for _day in range(len(days))\n",
    "        ])\n",
    "    d_reconstruct = np.sqrt(np.array(_rows))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "amended-architect",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Square root of the energy distance between the ground-truth points of each day\n",
    "# and the gWOT samples for that day: one row per gWOT file, one column per day.\n",
    "# Uses x_gt_all_gwot / x_all_gwot, the names defined when the gWOT files are\n",
    "# loaded (the previous x_gt_gwot_all / x_gwot_all were never defined).\n",
    "d_gwot = np.sqrt(np.array([\n",
    "    [dcor.energy_distance(x_gt_all_gwot[j][day_idx == i, :], x_all_gwot[j][i, :]) for i in range(len(days))]\n",
    "    for j in tqdm(range(len(x_all_gwot)), position = 0, leave = True)\n",
    "]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "international-advancement",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Baseline: square root of the energy distance between the ground-truth points of\n",
    "# each day and the input samples of that run whose t_idx equals the day index.\n",
    "_rows = []\n",
    "for _run in tqdm(range(len(x_all)), position = 0, leave = True):\n",
    "    _ts = tsdata_all[_run]\n",
    "    _rows.append([\n",
    "        dcor.energy_distance(x_gt_all[_run][day_idx == _day, :], _ts.x[_ts.t_idx == _day, :])\n",
    "        for _day in range(len(days))\n",
    "    ])\n",
    "d_sample = np.sqrt(np.array(_rows))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "acknowledged-planet",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sorted distinct parameter values; these define the axes of the result grids\n",
    "# built in the following cells. Only the values are needed, so np.unique is\n",
    "# called without return_index (the index array used to be discarded into `_`,\n",
    "# which rebinds IPython's last-output variable).\n",
    "N_vals = np.unique(N_all)\n",
    "N_vals_gwot = np.unique(N_all_gwot)\n",
    "lamda_vals = np.unique(lamda_all)\n",
    "lamda_vals_gwot = np.unique(lamda_all_gwot)\n",
    "srand_vals = np.unique(srand_all)\n",
    "srand_vals_gwot = np.unique(srand_all_gwot)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "marine-substitute",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Arrange the per-run errors on an (N, lamda, seed, day) grid. Parameter\n",
    "# combinations for which no run exists keep their NaN fill value.\n",
    "_grid_shape = (len(N_vals), len(lamda_vals), len(srand_vals), d_reconstruct.shape[-1])\n",
    "d_reconstruct_tensor = np.full(_grid_shape, np.nan)\n",
    "for (_N, _lamda, _srand) in zip(N_all, lamda_all, srand_all):\n",
    "    # Rows of d_reconstruct that belong to this parameter combination.\n",
    "    _run_mask = (N_all == _N) & (lamda_all == _lamda) & (srand_all == _srand)\n",
    "    d_reconstruct_tensor[N_vals == _N, lamda_vals == _lamda, srand_vals == _srand, :] = d_reconstruct[_run_mask, :].ravel()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "revised-plaza",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Arrange the per-run gWOT errors on an (N, lamda, seed, day) grid. Parameter\n",
    "# combinations for which no run exists keep their NaN fill value.\n",
    "_grid_shape_gwot = (len(N_vals_gwot), len(lamda_vals_gwot), len(srand_vals_gwot), d_gwot.shape[-1])\n",
    "d_gwot_tensor = np.full(_grid_shape_gwot, np.nan)\n",
    "for (_N, _lamda, _srand) in zip(N_all_gwot, lamda_all_gwot, srand_all_gwot):\n",
    "    # Rows of d_gwot that belong to this parameter combination.\n",
    "    _run_mask = (N_all_gwot == _N) & (lamda_all_gwot == _lamda) & (srand_all_gwot == _srand)\n",
    "    d_gwot_tensor[N_vals_gwot == _N, lamda_vals_gwot == _lamda, srand_vals_gwot == _srand, :] = d_gwot[_run_mask, :].ravel()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "burning-token",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Arrange the per-run subsample errors on an (N, lamda, seed, day) grid.\n",
    "# Parameter combinations for which no run exists keep their NaN fill value.\n",
    "_grid_shape = (len(N_vals), len(lamda_vals), len(srand_vals), d_sample.shape[-1])\n",
    "d_sample_tensor = np.full(_grid_shape, np.nan)\n",
    "for (_N, _lamda, _srand) in zip(N_all, lamda_all, srand_all):\n",
    "    # Rows of d_sample that belong to this parameter combination.\n",
    "    _run_mask = (N_all == _N) & (lamda_all == _lamda) & (srand_all == _srand)\n",
    "    d_sample_tensor[N_vals == _N, lamda_vals == _lamda, srand_vals == _srand, :] = d_sample[_run_mask, :].ravel()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "endless-prisoner",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One error-bar curve per lamda value and method (mean and standard deviation\n",
    "# over the runs sharing that lamda), drawn method by method:\n",
    "# blue = d_reconstruct, red = d_gwot, green = d_sample.\n",
    "_series = (\n",
    "    (d_reconstruct, lamda_all, \"blue\"),\n",
    "    (d_gwot, lamda_all_gwot, \"red\"),\n",
    "    (d_sample, lamda_all, \"green\"),\n",
    ")\n",
    "for _dist, _lamdas, _color in _series:\n",
    "    for l in np.unique(_lamdas):\n",
    "        _rows = _dist[_lamdas == l, :]\n",
    "        plt.errorbar(days, _rows.mean(0), _rows.std(0), color = _color)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "monthly-modern",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Subsample baseline taken from grid entry [0, 0]: the overall mean, and a band\n",
    "# of +/- the standard deviation across seeds of the per-seed means over days.\n",
    "_baseline = d_sample_tensor[0, 0, :, :]\n",
    "_baseline_mean = _baseline.mean()\n",
    "_baseline_band = [_baseline_mean + _baseline.mean(1).std(),\n",
    "                  _baseline_mean - _baseline.mean(1).std()]\n",
    "\n",
    "# Left panel: Langevin error against lamda (mean over seeds and days, error bar\n",
    "# = spread across seeds of the per-seed mean over days).\n",
    "plt.subplot(1, 2, 1)\n",
    "_err_mean = np.nanmean(d_reconstruct_tensor, axis = (2, 3)).flatten()\n",
    "_err_std = np.nanstd(np.nanmean(d_reconstruct_tensor, axis = 3), axis = 2).flatten()\n",
    "plt.errorbar(lamda_vals, _err_mean, _err_std, marker = \"o\", color = \"blue\")\n",
    "plt.hlines(_baseline_mean, min(lamda_vals), max(lamda_vals), color = \"green\")\n",
    "plt.hlines(_baseline_band, min(lamda_vals), max(lamda_vals), linestyle = 'dashed', color = \"green\", label = \"samples\")\n",
    "plt.title(\"Langevin\")\n",
    "plt.xlabel(\"$\\\\lambda$\")\n",
    "plt.legend()\n",
    "plt.ylim(0.55, 1.75)\n",
    "plt.xscale(\"log\")\n",
    "\n",
    "# Right panel: the same summary for gWOT, against the same subsample baseline.\n",
    "plt.subplot(1, 2, 2)\n",
    "_err_mean_gwot = np.nanmean(d_gwot_tensor, axis = (2, 3)).flatten()\n",
    "_err_std_gwot = np.nanstd(np.nanmean(d_gwot_tensor, axis = 3), axis = 2).flatten()\n",
    "plt.errorbar(lamda_vals_gwot, _err_mean_gwot, _err_std_gwot, marker = \"o\", color = \"red\")\n",
    "plt.hlines(_baseline_mean, min(lamda_vals_gwot), max(lamda_vals_gwot), color = \"green\")\n",
    "plt.hlines(_baseline_band, min(lamda_vals_gwot), max(lamda_vals_gwot), linestyle = \"dashed\", color = \"green\", label = \"samples\")\n",
    "plt.legend()\n",
    "plt.xlabel(\"$\\\\lambda$\")\n",
    "plt.xscale(\"log\")\n",
    "plt.title(\"gWOT\")\n",
    "plt.ylim(0.55, 1.75)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "periodic-diversity",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display the distinct lamda values found among the gWOT result files.\n",
    "lamda_vals_gwot"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "chubby-cornwall",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize = (PLT_CELL, PLT_CELL))\n",
    "\n",
    "# For each method pick the lamda index with the smallest error averaged over\n",
    "# seeds and days; the subsample baseline always uses grid entry [0, 0].\n",
    "_best_mfl = np.argmin(np.nanmean(d_reconstruct_tensor, axis = (2, 3)).flatten())\n",
    "_best_gwot = np.argmin(np.nanmean(d_gwot_tensor, axis = (2, 3)).flatten())\n",
    "_curves = (\n",
    "    (d_reconstruct_tensor[0, _best_mfl, :, :], \"blue\", \"MFL\"),\n",
    "    (d_gwot_tensor[0, _best_gwot, :, :], \"red\", \"gWOT\"),\n",
    "    (d_sample_tensor[0, 0, :, :], \"green\", \"Subsample\"),\n",
    ")\n",
    "\n",
    "# Per-day mean and standard deviation across seeds (NaN entries ignored).\n",
    "for tmp, _color, _label in _curves:\n",
    "    plt.errorbar(days, np.nanmean(tmp, 0).flatten(), np.nanstd(tmp, 0).flatten(), marker = \"o\", color = _color, label = _label)\n",
    "\n",
    "plt.xlabel(\"day\")\n",
    "plt.ylabel(\"Energy Distance\")\n",
    "plt.legend(prop = {\"size\" : 8})\n",
    "plt.ylim(0.25, 2.5)\n",
    "plt.title(\"Error\")\n",
    "plt.tight_layout()\n",
    "plt.savefig(\"../reprogramming_distances.pdf\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "beautiful-hepatitis",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display the shape of the ground-truth array stored in the first result file.\n",
    "x_gt_all[0].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "black-desire",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run shown in the snapshot figures: the first result file with lamda == 0.025.\n",
    "i = np.where(lamda_all == 0.025)[0][0]\n",
    "# Number of model points per day, read from axis 1 of the model output so that\n",
    "# the colour vector below always has exactly one entry per plotted point\n",
    "# (this used to be hard-coded as 500).\n",
    "M = x_all[i].shape[1]\n",
    "\n",
    "fig = plt.figure(figsize = (3*PLT_CELL, PLT_CELL))\n",
    "\n",
    "# Middle panel: model output, first two coordinates; each day's colour value is\n",
    "# repeated M times to match the day-major flattening of the (day, point) array.\n",
    "plt.subplot(1, 3, 2)\n",
    "with torch.no_grad():\n",
    "    plt.scatter(x_all[i][:, :, 0], x_all[i][:, :, 1], c = np.kron(np.linspace(0, 1, len(days)), np.ones(M)), alpha = 0.5, s=  4)\n",
    "plt.xlabel(\"PC1\"); plt.ylabel(\"PC2\")\n",
    "plt.gca().get_yaxis().set_visible(False)\n",
    "plt.title(\"MFL\")\n",
    "plt.xlim(-20, 20); plt.ylim(-20, 20)\n",
    "\n",
    "# Left panel: the input samples of the same run, coloured by their time index.\n",
    "plt.subplot(1, 3, 1)\n",
    "plt.scatter(tsdata_all[i].x[:, 0], tsdata_all[i].x[:, 1], c = tsdata_all[i].t_idx, alpha = 1, s = 4)\n",
    "plt.xlabel(\"PC1\"); plt.ylabel(\"PC2\")\n",
    "plt.title(\"Subsample\")\n",
    "plt.xlim(-20, 20); plt.ylim(-20, 20)\n",
    "\n",
    "# Right panel: all ground-truth points, coloured by day.\n",
    "plt.subplot(1, 3, 3)\n",
    "im = plt.scatter(x_gt_all[i][:, 0], x_gt_all[i][:, 1], c = day_gt, alpha = 0.05, s = 4, rasterized = True)\n",
    "plt.xlabel(\"PC1\"); plt.ylabel(\"PC2\")\n",
    "plt.title(\"Full dataset\")\n",
    "plt.gca().get_yaxis().set_visible(False)\n",
    "plt.xlim(-20, 20); plt.ylim(-20, 20)\n",
    "\n",
    "plt.tight_layout()\n",
    "\n",
    "# Reserve the right margin for a colour bar tied to the ground-truth scatter.\n",
    "fig.subplots_adjust(right=0.9)\n",
    "cbar_ax = fig.add_axes([0.925, 0.15, 0.025, 0.7])\n",
    "cb = fig.colorbar(im, cax=cbar_ax)\n",
    "# The scatter uses alpha = 0.05; draw the colour bar itself fully opaque.\n",
    "# NOTE(review): Colorbar.draw_all() is deprecated in recent Matplotlib releases --\n",
    "# confirm against the Matplotlib version this notebook is run with.\n",
    "cb.set_alpha(1)\n",
    "cb.draw_all()\n",
    "cbar_ax.set_title(\"day\")\n",
    "\n",
    "plt.savefig(\"../reprogramming_snapshots.pdf\", dpi = 300)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "miniature-essay",
   "metadata": {},
   "outputs": [],
   "source": [
    "import anndata\n",
    "import umap"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "welsh-studio",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ADATA_PATH = \"data_repr.h5ad\"\n",
    "# adata = anndata.read_h5ad(ADATA_PATH)\n",
    "# adata = adata[(adata.obs.day >= 2.5) & (adata.obs.day < 6.5), :]\n",
    "\n",
    "# Fit the UMAP embedding on the ground-truth points of the selected run i.\n",
    "trans = umap.UMAP(n_neighbors = 25, verbose = True)\n",
    "X_gt_umap = trans.fit_transform(x_gt_all[i])\n",
    "\n",
    "plt.scatter(X_gt_umap[:, 0], X_gt_umap[:, 1], c = day_gt, alpha = 0.1, marker = \".\")\n",
    "\n",
    "# Project the input samples of the same run i into the fitted embedding. The\n",
    "# figure in the next cell colours these points with tsdata_all[i].t_idx, so the\n",
    "# coordinates must come from run i as well (run 0 was used here before).\n",
    "X_sample_umap = trans.transform(tsdata_all[i].x)\n",
    "\n",
    "# Project the model output, flattened to one row per (day, point) pair.\n",
    "with torch.no_grad():\n",
    "    X_langevin_umap = trans.transform(x_all[i].reshape(-1, x_all[i].shape[-1]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "affiliated-ebony",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Three-panel snapshot figure in UMAP coordinates. Relies on i and M from the\n",
    "# PCA snapshot cell and on the embeddings computed in the previous cell.\n",
    "fig = plt.figure(figsize = (3*PLT_CELL, PLT_CELL))\n",
    "# Middle panel: model output; each day's colour value is repeated M times.\n",
    "plt.subplot(1, 3, 2)\n",
    "plt.scatter(X_langevin_umap[:, 0], X_langevin_umap[:, 1], c = np.kron(np.linspace(0, 1, len(days)), np.ones(M)), alpha = 0.5, s=  4)\n",
    "plt.xlabel(\"UMAP1\"); plt.ylabel(\"UMAP2\")\n",
    "plt.gca().get_yaxis().set_visible(False)\n",
    "plt.title(\"MFL\")\n",
    "# Left panel: embedded input samples, coloured by the time index of run i.\n",
    "plt.subplot(1, 3, 1)\n",
    "plt.scatter(X_sample_umap[:, 0], X_sample_umap[:, 1], c = tsdata_all[i].t_idx, alpha = 1, s = 4)\n",
    "plt.xlabel(\"UMAP1\"); plt.ylabel(\"UMAP2\")\n",
    "plt.title(\"Subsample\")\n",
    "# Right panel: embedded ground-truth points, coloured by day.\n",
    "plt.subplot(1, 3, 3)\n",
    "im = plt.scatter(X_gt_umap[:, 0], X_gt_umap[:, 1], c = day_gt, alpha = 0.05, s = 4, rasterized = True)\n",
    "plt.xlabel(\"UMAP1\"); plt.ylabel(\"UMAP2\")\n",
    "plt.title(\"Full dataset\")\n",
    "plt.gca().get_yaxis().set_visible(False)\n",
    "\n",
    "plt.tight_layout()\n",
    "\n",
    "# Reserve the right margin for a colour bar tied to the ground-truth scatter.\n",
    "fig.subplots_adjust(right=0.9)\n",
    "cbar_ax = fig.add_axes([0.925, 0.15, 0.025, 0.7])\n",
    "cb = fig.colorbar(im, cax=cbar_ax)\n",
    "# The scatter uses alpha = 0.05; draw the colour bar itself fully opaque.\n",
    "cb.set_alpha(1)\n",
    "cb.draw_all()\n",
    "cbar_ax.set_title(\"day\")\n",
    "\n",
    "# NOTE(review): this is the same path the PCA snapshot cell saves to, so running\n",
    "# both cells leaves only this figure on disk -- confirm that is intended.\n",
    "plt.savefig(\"../reprogramming_snapshots.pdf\", dpi = 300)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "contrary-bookmark",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
