{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "from tqdm import tqdm\n",
    "from utils import pltManager\n",
    "import matplotlib.pyplot as plt\n",
    "from args_settings import select_args_specifications\n",
    "from load_and_test_models import visualize_args_latent_space"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "translation = {\n",
    "    \"yahoo\":\"Yahoo\", \"yelp\":\"Yelp\", \"snli\":\"SNLI\", \"short_yelp\":\"Short-Yelp\",\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def visualize(model_names, args_specifications, save_name=None):\n",
    "    # given models * four datasets * two modes * three axis tuples\n",
    "    save_path = f\"{save_name if save_name else model_names[0]}.png\"\n",
    "    lock_path = f\"lock-{save_path}.txt\"\n",
    "    if os.path.exists(save_path):\n",
    "        return\n",
    "    if os.path.exists(lock_path):\n",
    "        return\n",
    "    with open(lock_path, \"w\") as f:\n",
    "        f.write(\"\")\n",
    "    \n",
    "    assert len(model_names)*4 == len(args_specifications)\n",
    "    model_names = model_names*4\n",
    "    \n",
    "    select_axis = [[0, 1], [6, 7], [12, 13], [18, 19], [24, 25], [30, 31]]\n",
    "    mode = [\"aggregated\", \"center\"]\n",
    "    columns = len(select_axis) * len(mode)\n",
    "    lines = len(args_specifications)\n",
    "    pltm = pltManager(plt, columns, lines)\n",
    "    \n",
    "    processed = 0\n",
    "    for args_specification,model_name in tqdm(zip(args_specifications,model_names), total=len(model_names)):\n",
    "        with open(lock_path, \"w\") as f:\n",
    "            f.write(f\"{processed}/{len(model_names)}\")\n",
    "        visualize_args_latent_space(args_specification, pltm, select_axis, mode, verbose=0,\n",
    "                                    model_name=model_name, dataset=translation[args_specification[1]])\n",
    "        processed += 1\n",
    "    \n",
    "    pltm.plt.savefig(save_path, dpi=300)\n",
    "    os.remove(lock_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#    VAE (default)   & -330.7 & -330.7 & 0.0 & 0.0 & 0 & 32 \\\\\n",
    "model_names = [\"VAE (default)\"]\n",
    "args_specifications = [\n",
    "    [\n",
    "        \"--dataset\", str(dataset),\n",
    "        \"--encoder_class\", str(encoder_class),\n",
    "    ] for dataset in [\n",
    "        \"yahoo\",\"yelp\",\"snli\",\"short_yelp\"\n",
    "    ] for encoder_class in [\n",
    "        \"GaussianLSTMEncoder\"\n",
    "    ]\n",
    "]\n",
    "visualize(model_names, args_specifications, \"VAE (default)\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#    cyclic-VAE      & -329.8 & -328.9 & 1.1 & 1.0 & 2 & 31 \\\\\n",
    "model_names = [\"cyclic-VAE\"]\n",
    "args_specifications = [\n",
    "    [\n",
    "        \"--dataset\", str(dataset),\n",
    "        \"--encoder_class\", str(encoder_class),\n",
    "        \"--cycle\", \"20\",\n",
    "    ] for dataset in [\n",
    "        \"yahoo\",\"yelp\",\"snli\",\"short_yelp\"\n",
    "    ] for encoder_class in [\n",
    "        \"GaussianLSTMEncoder\"\n",
    "    ]\n",
    "]\n",
    "visualize(model_names, args_specifications, \"cyclic-VAE\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#    bow-VAE         & -330.5 & -330.5 & 0.0 & 0.0 & 0 & 32 \\\\\n",
    "model_names = [\"bow-VAE\"]\n",
    "args_specifications = [\n",
    "    [\n",
    "        \"--dataset\", str(dataset),\n",
    "        \"--encoder_class\", str(encoder_class),\n",
    "        \"--add_bow\",\n",
    "    ] for dataset in [\n",
    "        \"yahoo\",\"yelp\",\"snli\",\"short_yelp\"\n",
    "    ] for encoder_class in [\n",
    "        \"GaussianLSTMEncoder\"\n",
    "    ]\n",
    "]\n",
    "visualize(model_names, args_specifications, \"bow-VAE\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#    skip-VAE        & -330.1 & -325.2 & 5.0 & 4.3 & 8 & 31 \\\\\n",
    "model_names = [\"skip-VAE\"]\n",
    "args_specifications = [\n",
    "    [\n",
    "        \"--dataset\", str(dataset),\n",
    "        \"--encoder_class\", str(encoder_class),\n",
    "        \"--add_skip\",\n",
    "    ] for dataset in [\n",
    "        \"yahoo\",\"yelp\",\"snli\",\"short_yelp\"\n",
    "    ] for encoder_class in [\n",
    "        \"GaussianLSTMEncoder\"\n",
    "    ]\n",
    "]\n",
    "visualize(model_names, args_specifications, \"skip-VAE\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#    $\\delta$-VAE(0.15)  & -330.5 & -330.6 & 4.8 & 0.0 & 0 & 0 \\\\\n",
    "model_names = [r\"$\\delta$-VAE(0.15)\"]\n",
    "args_specifications = [\n",
    "    [\n",
    "        \"--dataset\", str(dataset),\n",
    "        \"--encoder_class\", str(encoder_class),\n",
    "    ] for dataset in [\n",
    "        \"yahoo\",\"yelp\",\"snli\",\"short_yelp\"\n",
    "    ] for encoder_class in [\n",
    "        \"DeltaGaussianLSTMEncoder\"\n",
    "    ]\n",
    "]\n",
    "visualize(model_names, args_specifications, \"delta-VAE\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#    BN-VAE(0.6)     & -327.6 & -321.1 & 6.6 & 5.9 & 32 & 32 \\\\\n",
    "#    %BN-VAE(0.7)     & -326.8 & -318.4 & 9.1 & 7.5 & 32 & 32 \\\\\n",
    "#    BN-VAE(0.9)     & -327.0 & -313.8 & 15.5 & 9.0 & 32 & 32 \\\\\n",
    "#    %BN-VAE(1.2)     & -330.9 & -310.1 & 26.2 & 9.2 & 32 & 0 \\\\\n",
    "#    %BN-VAE(1.5)     & -337.8 & -310.2 & 37.6 & 9.2 & 32 & 0 \\\\\n",
    "#    BN-VAE(1.8)     & -343.5 & -308.6 & 51.3 & 9.2 & 32 & 0 \\\\\n",
    "model_names = [f\"BN-VAE({gamma:.1f})\" for gamma in [0.6, 0.7, 0.9, 1.2, 1.5, 1.8]]\n",
    "args_specifications = [\n",
    "    [\n",
    "        \"--dataset\", str(dataset),\n",
    "        \"--encoder_class\", str(encoder_class),\n",
    "        \"--gamma\", str(gamma),\n",
    "    ] for dataset in [\n",
    "        \"yahoo\",\"yelp\",\"snli\",\"short_yelp\"\n",
    "    ] for encoder_class in [\n",
    "        \"BNGaussianLSTMEncoder\"\n",
    "    ] for gamma in [\n",
    "        0.6, 0.7, 0.9, 1.2, 1.5, 1.8\n",
    "    ]\n",
    "]\n",
    "visualize(model_names, args_specifications, \"BN-VAEs\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#    FB-VAE(4)       & -329.8 & -328.4 & 3.9 & 1.8 & 32 & 32 \\\\\n",
    "#    %FB-VAE(9)       & -327.8 & -326.3 & 8.8 & 4.1 & 32 & 12 \\\\\n",
    "#    FB-VAE(16)      & -325.7 & -320.8 & 16.1 & 8.5 & 32 & 8 \\\\\n",
    "#    %FB-VAE(25)      & -333.4 & -316.2 & 25.8 & 9.2 & 32 & 0 \\\\\n",
    "#    %FB-VAE(36)      & -341.3 & -307.0 & 36.9 & 9.2 & 32 & 0 \\\\\n",
    "#    FB-VAE(49)      & -344.6 & -296.1 & 50.0 & 9.2 & 32 & 0 \\\\\n",
    "model_names = [f\"FB-VAE({target_kl:d})\" for target_kl in [4, 9, 16, 25, 36, 49]]\n",
    "args_specifications = [\n",
    "    [\n",
    "        \"--dataset\", str(dataset),\n",
    "        \"--encoder_class\", str(encoder_class),\n",
    "        \"--target_kl\", str(target_kl),\n",
    "    ] for dataset in [\n",
    "        \"yahoo\",\"yelp\",\"snli\",\"short_yelp\"\n",
    "    ] for encoder_class in [\n",
    "        \"FineFBGaussianLSTMEncoder\"\n",
    "    ] for target_kl in [\n",
    "        4, 9, 16, 25, 36, 49\n",
    "    ]\n",
    "]\n",
    "visualize(model_names, args_specifications, \"FB-VAEs\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#    %$\\beta$-VAE(0.8)    & -330.1 & -328.5 & 2.0 & 1.9 & 2 & 30 \\\\\n",
    "#    $\\beta$-VAE(0.4)    & -330.8 & -324.8 & 7.0 & 6.7 & 3 & 31 \\\\\n",
    "#    $\\beta$-VAE(0.2)    & -338.6 & -310.3 & 30.1 & 9.2 & 22 & 25 \\\\\n",
    "#    $\\beta$-VAE(0.1)    & -369.9 & -289.6 & 83.7 & 9.2 & 32 & 0 \\\\\n",
    "#    %$\\beta$-VAE(0.0)    & -445.2 & -280.3 & 178.8 & 9.2 & 32 & 0 \\\\\n",
    "model_names = [r\"$\\beta$-VAE({})\".format(beta) for beta in [1.0, 0.8, 0.4, 0.2, 0.1, 0.0]]\n",
    "args_specifications = [\n",
    "    [\n",
    "        \"--dataset\", str(dataset),\n",
    "        \"--encoder_class\", str(encoder_class),\n",
    "        \"--kl_beta\", str(kl_beta),\n",
    "    ] for dataset in [\n",
    "        \"yahoo\",\"yelp\",\"snli\",\"short_yelp\"\n",
    "    ] for encoder_class in [\n",
    "        \"GaussianLSTMEncoder\"\n",
    "    ] for kl_beta in [\n",
    "        1.0, 0.8, 0.4, 0.2, 0.1, 0.0\n",
    "    ]\n",
    "]\n",
    "visualize(model_names, args_specifications, \"Beta-VAEs\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#    MCo-VAE ($|b|=1$)   & -330.7 & -330.7 & 0.0 & 0.0 & 0 & 32 \\\\\n",
    "#    %MCo-VAE ($|b|=2$)   & -330.1 & -326.4 & 4.1 & 4.0 & 4 & 32 \\\\\n",
    "#    MCo-VAE ($|b|=4$)   & -330.4 & -318.3 & 14.3 & 9.1 & 11 & 32 \\\\\n",
    "#    %MCo-VAE ($|b|=8$)   & -338.2 & -308.2 & 32.1 & 9.1 & 30 & 32 \\\\\n",
    "#    %MCo-VAE ($|b|=16$)  & -349.5 & -295.0 & 57.6 & 9.1 & 32 & 32 \\\\\n",
    "#    MCo-VAE ($|b|=32$)  & -355.4 & -294.1 & 65.2 & 9.1 & 32 & 32 \\\\\n",
    "model_names = [r\"DG-VAE ($|b|={}$)\".format(agg_size) for agg_size in [1,2,4,8,16,32]]\n",
    "args_specifications = [\n",
    "    [\n",
    "        \"--dataset\", str(dataset),\n",
    "        \"--encoder_class\", str(encoder_class),\n",
    "        \"--agg_size\", str(agg_size),\n",
    "    ] for dataset in [\n",
    "        \"yahoo\",\"yelp\",\"snli\",\"short_yelp\"\n",
    "    ] for encoder_class in [\n",
    "        \"MCoGaussianLSTMEncoder\"\n",
    "    ] for agg_size in [\n",
    "        1, 2, 4, 8, 16, 32\n",
    "    ]\n",
    "]\n",
    "visualize(model_names, args_specifications, \"DG-VAEs\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if __name__ == \"__main__\":\n",
    "    save_path = \"collapse_and_hole_5.png\"\n",
    "    if not os.path.exists(f\"{save_path}\") and not os.path.exists(f\"lock-{save_path}\"):\n",
    "        with open(f\"lock-{save_path}\", \"w\") as f:\n",
    "            f.write(\"\")\n",
    "        \n",
    "        args_specifications = [\n",
    "            [\n",
    "                \"--dataset\", str(dataset),\n",
    "                \"--encoder_class\", str(encoder_class),\n",
    "                str(specific_1), str(specific_2)\n",
    "            ] for dataset in [\n",
    "                \"yahoo\",\n",
    "            ] for encoder_class, specific_1, specific_2 in [\n",
    "                (\"GaussianLSTMEncoder\", \"\", \"\"),\n",
    "                (\"BNGaussianLSTMEncoder\", \"--gamma\", \"1.2\"),\n",
    "                (\"GaussianLSTMEncoder\", \"--kl_beta\", \"0.4\"),\n",
    "                (\"FineFBGaussianLSTMEncoder\", \"--target_kl\", \"36\"),\n",
    "                (\"MCoGaussianLSTMEncoder\", \"\", \"\")\n",
    "            ]\n",
    "        ]\n",
    "        model_names = [\"VAE\", \"BN-VAE(1.2)\", r\"$\\beta$-VAE(0.4)\",  \"FB-VAE(36)\",\"DG-VAE\"]\n",
    "\n",
    "        select_datasets = range(1)\n",
    "        select_models = range(len(model_names))\n",
    "        select_axis = [\"max_var\"]\n",
    "        mode = [\"aggregated\", \"center\"]\n",
    "        columns, lines = len(select_axis)*len(select_models), len(select_datasets)*len(mode)\n",
    "        pltm = pltManager(plt, columns, lines)\n",
    "\n",
    "        num_models = len(select_models)\n",
    "        args_specifications = [args_specifications[i+num_models*j] for j in select_datasets for i in select_models]\n",
    "        for i,(args_specification,model_name) in tqdm(enumerate(zip(args_specifications,model_names)), total=len(model_names)*len(select_datasets)):\n",
    "            visualize_args_latent_space(args_specification, pltm, select_axis, [\"aggregated\"], verbose=0, model_name=model_name)\n",
    "        for i,(args_specification,model_name) in tqdm(enumerate(zip(args_specifications,model_names)), total=len(model_names)*len(select_datasets)):\n",
    "            visualize_args_latent_space(args_specification, pltm, select_axis, [\"center\"], verbose=0, model_name=model_name)\n",
    "        #pltm.plt.suptitle(\"title\", fontsize=30, ha=\"center\")\n",
    "\n",
    "        pltm.plt.savefig(f\"{save_path}\", dpi=300)\n",
    "        os.remove(f\"lock-{save_path}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if __name__ == \"__main__\":\n",
    "    save_path = \"gaussian_models_max.png\"\n",
    "    if not os.path.exists(f\"{save_path}\") and not os.path.exists(f\"lock-{save_path}\"):\n",
    "        with open(f\"lock-{save_path}\", \"w\") as f:\n",
    "            f.write(\"\")\n",
    "        \n",
    "        args_specifications = [\n",
    "            [\n",
    "                \"--dataset\", str(dataset),\n",
    "                \"--encoder_class\", str(encoder_class),\n",
    "                str(specific_1), str(specific_2)\n",
    "            ] for dataset in [\n",
    "                \"yahoo\",\n",
    "            ] for encoder_class, specific_1, specific_2 in ([\n",
    "                (\"BNGaussianLSTMEncoder\", \"--gamma\", str(gamma)) for gamma in [0.6, 0.7, 0.9, 1.2, 1.5, 1.8]\n",
    "            ] + [\n",
    "                (\"FineFBGaussianLSTMEncoder\", \"--target_kl\", str(target_kl)) for target_kl in [4, 9, 16, 25, 36, 49]\n",
    "            ] + [\n",
    "                (\"GaussianLSTMEncoder\", \"--kl_beta\", str(beta)) for beta in [1.0, 0.8, 0.4, 0.2, 0.1, 0.0]\n",
    "            ] + [\n",
    "                (\"MCoGaussianLSTMEncoder\", \"--agg_size\", str(agg_size)) for agg_size in [1,2,4,8,16,32]\n",
    "            ])\n",
    "        ]\n",
    "        model_names = [\n",
    "            f\"BN-VAE({gamma:.1f})\" for gamma in [0.6, 0.7, 0.9, 1.2, 1.5, 1.8]\n",
    "        ] + [\n",
    "            f\"FB-VAE({target_kl:d})\" for target_kl in [4, 9, 16, 25, 36, 49]\n",
    "        ] + [\n",
    "            r\"$\\beta$-VAE({})\".format(beta) for beta in [1.0, 0.8, 0.4, 0.2, 0.1, 0.0]\n",
    "        ] + [\n",
    "            r\"DG-VAE ($|b|={}$)\".format(agg_size) for agg_size in [1,2,4,8,16,32]\n",
    "        ]\n",
    "        print(model_names)\n",
    "\n",
    "        select_datasets = range(1)\n",
    "        select_models = range(24)\n",
    "        select_axis = [\"max_var\"]\n",
    "        mode = [\"aggregated\", \"center\"]\n",
    "        columns, lines = len(select_axis)*len(mode)*6, len(select_datasets)*4\n",
    "        pltm = pltManager(plt, columns, lines)\n",
    "\n",
    "        num_models = len(select_models)\n",
    "        args_specifications = [args_specifications[i+num_models*j] for j in select_datasets for i in select_models]\n",
    "        for i,(args_specification,model_name) in tqdm(enumerate(zip(args_specifications,model_names)), total=len(model_names)):\n",
    "            visualize_args_latent_space(args_specification, pltm, select_axis, mode, verbose=0, model_name=model_name)\n",
    "        #pltm.plt.suptitle(\"title\", fontsize=30, ha=\"center\")\n",
    "\n",
    "        pltm.plt.savefig(f\"{save_path}\", dpi=300)\n",
    "        os.remove(f\"lock-{save_path}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if __name__ == \"__main__\":\n",
    "    save_path = \"collapse_and_hole.png\"\n",
    "    if not os.path.exists(f\"{save_path}\") and not os.path.exists(f\"lock-{save_path}\"):\n",
    "        with open(f\"lock-{save_path}\", \"w\") as f:\n",
    "            f.write(\"\")\n",
    "        \n",
    "        args_specifications = [\n",
    "            [\n",
    "                \"--dataset\", str(dataset),\n",
    "                \"--encoder_class\", str(encoder_class),\n",
    "                str(specific_1), str(specific_2)\n",
    "            ] for dataset in [\n",
    "                \"yahoo\",\"yelp\"\n",
    "            ] for encoder_class, specific_1, specific_2 in [\n",
    "                (\"GaussianLSTMEncoder\", \"\", \"\"),\n",
    "                (\"GaussianLSTMEncoder\", \"--kl_beta\", \"0.0\"),\n",
    "                (\"MCoGaussianLSTMEncoder\", \"\", \"\")\n",
    "            ]\n",
    "        ]\n",
    "        model_names = [\"VAE\", \"AE\", \"DG-VAE\"]\n",
    "\n",
    "        select_datasets = range(1)\n",
    "        select_models = range(3)\n",
    "        select_axis = [\"mid_var\"]\n",
    "        mode = [\"aggregated\", \"center\"]\n",
    "        columns, lines = len(select_axis)*len(mode)*len(select_models), len(select_datasets)\n",
    "        pltm = pltManager(plt, columns, lines)\n",
    "\n",
    "        num_models = len(select_models)\n",
    "        args_specifications = [args_specifications[i+num_models*j] for j in select_datasets for i in select_models]\n",
    "        for i,(args_specification,model_name) in tqdm(enumerate(zip(args_specifications,model_names)), total=len(model_names)*len(select_datasets)):\n",
    "            visualize_args_latent_space(args_specification, pltm, select_axis, mode, verbose=0, model_name=model_name)\n",
    "        #pltm.plt.suptitle(\"title\", fontsize=30, ha=\"center\")\n",
    "\n",
    "        pltm.plt.savefig(f\"{save_path}\", dpi=300)\n",
    "        os.remove(f\"lock-{save_path}\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
