{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "cuda:0\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "import torch.optim as optim\n",
    "import torch.optim.lr_scheduler as lr_scheduler\n",
    "import pickle\n",
    "\n",
    "# Project helpers (models, data generation/loading, squareform utilities).\n",
    "# NOTE: the stdlib imports below deliberately come *after* these wildcard imports,\n",
    "# so names bound there (argparse, time, logging, random, np) cannot be shadowed by re-exports.\n",
    "from src.models import *\n",
    "from src.utils import *\n",
    "from src.utils_data import *\n",
    "import argparse\n",
    "import time\n",
    "import logging\n",
    "import random\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "# Seed every global RNG so that Restart & Run All gives repeatable model init / shuffling.\n",
    "# NOTE(review): whether generate_WS_parallel / data_loading draw from these global RNGs\n",
    "# is not visible here -- confirm in src.utils_data.\n",
    "SEED = 0\n",
    "random.seed(SEED)\n",
    "np.random.seed(SEED)\n",
    "torch.manual_seed(SEED)  # seeds the CPU and all CUDA devices\n",
    "\n",
    "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
    "print(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "graph_type = 'WS'\n",
    "graph_size = 50\n",
    "num_unroll = 20\n",
    "\n",
    "# Log INFO messages to a per-experiment file (overwritten on each run, filemode='w')\n",
    "# and echo them to the notebook's stderr.\n",
    "# NOTE: logging.basicConfig does nothing once the root logger already has handlers, so\n",
    "# after changing the parameters above, restart the kernel to get a correctly named log file.\n",
    "log_path = 'logs/L2G_{}_m{}_x{}.log'.format(graph_type, graph_size, num_unroll)\n",
    "os.makedirs(os.path.dirname(log_path), exist_ok=True)  # the FileHandler cannot create logs/ itself\n",
    "\n",
    "logging.basicConfig(filename=log_path,\n",
    "                    filemode='w',\n",
    "                    format='%(asctime)s - %(message)s',\n",
    "                    datefmt='%d-%b-%y %H:%M:%S',\n",
    "                    level=logging.INFO)\n",
    "\n",
    "# Attach the console handler only once, so re-running this cell does not duplicate every log line.\n",
    "root_logger = logging.getLogger()\n",
    "if not any(h.get_name() == 'l2g_console' for h in root_logger.handlers):\n",
    "    console = logging.StreamHandler()\n",
    "    console.set_name('l2g_console')\n",
    "    console.setLevel(logging.INFO)\n",
    "    formatter = logging.Formatter('%(asctime)s | %(message)s', datefmt='%d-%b-%y %H:%M:%S')\n",
    "    console.setFormatter(formatter)\n",
    "    root_logger.addHandler(console)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
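    "# Data\n",
    "\n",
    "Generate the synthetic WS (Watts-Strogatz) graph dataset (only needed once, since it is pickled to `data/`), then load it into train / val / test loaders and plot one example."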
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Generate synthetic WS graphs and pickle them for the load cell below.\n",
    "# Only needs to run once per (graph_type, graph_size) configuration.\n",
    "import os\n",
    "\n",
    "graph_type = 'WS'\n",
    "edge_type = 'lognormal'\n",
    "graph_size = 50\n",
    "\n",
    "# NOTE(review): presumably Watts-Strogatz parameters (k nearest neighbours, rewiring\n",
    "# probability p) -- confirm against generate_WS_parallel in src.utils_data.\n",
    "graph_hyper = {'k': 5,\n",
    "               'p': 0.3}\n",
    "\n",
    "# Resolve the output path and create data/ *before* the expensive generation step,\n",
    "# so a missing directory fails fast instead of after all samples are built.\n",
    "data_path = 'data/dataset_{}_{}nodes.pickle'.format(graph_type, graph_size)\n",
    "os.makedirs(os.path.dirname(data_path), exist_ok=True)\n",
    "\n",
    "data = generate_WS_parallel(num_samples=8064,\n",
    "                            num_signals=3000,\n",
    "                            num_nodes=graph_size,\n",
    "                            graph_hyper=graph_hyper,\n",
    "                            weighted=edge_type,\n",
    "                            weight_scale=True)\n",
    "\n",
    "with open(data_path, 'wb') as handle:\n",
    "    pickle.dump(data, handle, protocol=4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "loading data at  data/dataset_WS_50nodes.pickle\n",
      "successfully loading: train 6400, val 1600, test 64, batch 32\n"
     ]
    }
   ],
   "source": [
    "# Load the pickled dataset written by the generation cell and wrap it in\n",
    "# train / val / test loaders (split sizes are reported by data_loading).\n",
    "batch_size = 32\n",
    "\n",
    "# NOTE: despite its name, data_dir is the path of the pickle file, not a directory.\n",
    "data_dir = f'data/dataset_{graph_type}_{graph_size}nodes.pickle'\n",
    "train_loader, val_loader, test_loader = data_loading(data_dir, batch_size=batch_size)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<Figure size 640x480 with 2 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "\n",
    "# Visual sanity check: heatmap of one graph from the test set.\n",
    "# Only one batch is needed, so take the first instead of iterating the whole loader\n",
    "# (which also avoided clobbering IPython's `_` output variable).\n",
    "# NOTE(review): presumably W holds squareform edge weights -- confirm in src.utils_data.\n",
    "first_batch = next(iter(test_loader))\n",
    "W = first_batch[1]\n",
    "eg = torch_sqaureform_to_matrix(W, device='cpu')  # (sic) name as defined in src.utils\n",
    "\n",
    "fig, ax = plt.subplots()\n",
    "sns.heatmap(eg[4], ax=ax)\n",
    "ax.set_title('Test sample 4 (first batch): weighted adjacency matrix')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "03-Jun-21 11:54:46 | learn2graph(\n",
      "  (vae): TopoDiffVAE(\n",
      "    (enc): GraphEnc(\n",
      "      (conv1): GraphConvLayer()\n",
      "      (conv2): GraphConvLayer()\n",
      "      (fc1): Linear(in_features=64, out_features=32, bias=True)\n",
      "      (fc2): Linear(in_features=32, out_features=16, bias=True)\n",
      "    )\n",
      "    (f_mean): Sequential(\n",
      "      (0): Linear(in_features=16, out_features=16, bias=True)\n",
      "    )\n",
      "    (f_var): Sequential(\n",
      "      (0): Linear(in_features=16, out_features=16, bias=True)\n",
      "    )\n",
      "    (dec): Sequential(\n",
      "      (0): Linear(in_features=1241, out_features=1633, bias=True)\n",
      "      (1): Tanh()\n",
      "      (2): Linear(in_features=1633, out_features=1225, bias=True)\n",
      "    )\n",
      "  )\n",
      ")\n"
     ]
    }
   ],
   "source": [
    "# --- model hyper-parameters ---\n",
    "# NOTE(review): num_unroll / graph_size are also set in the logging cell, whose log\n",
    "# filename is built from them -- keep both cells in sync.\n",
    "num_unroll = 20\n",
    "graph_size = 50\n",
    "n_hid = 32\n",
    "n_latent = 16\n",
    "n_nodeFeat = 1\n",
    "n_graphFeat = 16\n",
    "\n",
    "# --- optimisation: Adam + exponential decay (lr is multiplied by lr_decay per scheduler.step()) ---\n",
    "# NOTE(review): the recorded log shows lr = 0.009025 (= lr * lr_decay**2) at epoch 1; either the\n",
    "# training loop steps the scheduler twice or it reads scheduler.get_lr() (which applies gamma\n",
    "# once more) instead of get_last_lr() -- check the training cell.\n",
    "lr = 1e-2\n",
    "lr_decay = 0.95\n",
    "\n",
    "net = learn2graph(num_unroll, graph_size, n_hid, n_latent, n_nodeFeat, n_graphFeat)\n",
    "net = net.to(device)\n",
    "\n",
    "optimizer = optim.Adam(params=net.parameters(), lr=lr)\n",
    "scheduler = lr_scheduler.ExponentialLR(optimizer, gamma=lr_decay)\n",
    "\n",
    "# Record the full architecture in both the log file and the console.\n",
    "logging.info(net)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "03-Jun-21 11:55:31 | Epoch 0001 | lr: 0.00902500 | Time(s): 40.8966\n",
      "03-Jun-21 11:55:31 | == train Loss <unroll: 431.1791 | vae : 17433609.2710 | kl : 17430613.8530>\n",
      "03-Jun-21 11:55:31 | == gmse <train: 473.5473 | val: 0.4936> \n",
      "03-Jun-21 11:56:12 | Epoch 0002 | lr: 0.00857375 | Time(s): 40.7864\n",
      "03-Jun-21 11:56:12 | == train Loss <unroll: 4.6488 | vae : 3.0605 | kl : 0.0197>\n",
      "03-Jun-21 11:56:12 | == gmse <train: 0.5693 | val: 0.5447> \n",
      "03-Jun-21 11:56:53 | Epoch 0003 | lr: 0.00814506 | Time(s): 40.8395\n",
      "03-Jun-21 11:56:53 | == train Loss <unroll: 4.5149 | vae : 3.5714 | kl : 0.0193>\n",
      "03-Jun-21 11:56:53 | == gmse <train: 0.6749 | val: 0.6758> \n",
      "03-Jun-21 11:57:33 | Epoch 0004 | lr: 0.00773781 | Time(s): 40.8260\n",
      "03-Jun-21 11:57:33 | == train Loss <unroll: 4.4031 | vae : 4.1391 | kl : 0.0190>\n",
      "03-Jun-21 11:57:33 | == gmse <train: 0.7915 | val: 0.7207> \n",
      "03-Jun-21 11:58:14 | Epoch 0005 | lr: 0.00735092 | Time(s): 40.8396\n",
      "03-Jun-21 11:58:14 | == train Loss <unroll: 4.4537 | vae : 5.2931 | kl : 0.0185>\n",
      "03-Jun-21 11:58:14 | == gmse <train: 1.0259 | val: 0.9657> \n",
      "03-Jun-21 11:58:55 | Epoch 0006 | lr: 0.00698337 | Time(s): 40.8422\n",
      "03-Jun-21 11:58:55 | == train Loss <unroll: 4.8447 | vae : 8.2592 | kl : 0.0180>\n",
      "03-Jun-21 11:58:55 | == gmse <train: 1.6245 | val: 1.5454> \n",
      "03-Jun-21 11:59:36 | Epoch 0007 | lr: 0.00663420 | Time(s): 40.8112\n",
      "03-Jun-21 11:59:36 | == train Loss <unroll: 5.8312 | vae : 14.6033 | kl : 0.0175>\n",
      "03-Jun-21 11:59:36 | == gmse <train: 2.8974 | val: 3.0107> \n",
      "03-Jun-21 12:00:17 | Epoch 0008 | lr: 0.00630249 | Time(s): 40.8457\n",
      "03-Jun-21 12:00:17 | == train Loss <unroll: 6.8572 | vae : 21.3574 | kl : 0.0170>\n",
      "03-Jun-21 12:00:17 | == gmse <train: 4.2364 | val: 4.0696> \n",
      "03-Jun-21 12:00:57 | Epoch 0009 | lr: 0.00598737 | Time(s): 40.8097\n",
      "03-Jun-21 12:00:57 | == train Loss <unroll: 7.6742 | vae : 27.2858 | kl : 0.0165>\n",
      "03-Jun-21 12:00:57 | == gmse <train: 5.3857 | val: 5.0986> \n",
      "03-Jun-21 12:01:38 | Epoch 0010 | lr: 0.00568800 | Time(s): 40.8366\n",
      "03-Jun-21 12:01:38 | == train Loss <unroll: 8.2627 | vae : 32.6528 | kl : 0.0160>\n",
      "03-Jun-21 12:01:38 | == gmse <train: 6.3942 | val: 5.6941> \n",
      "03-Jun-21 12:02:19 | Epoch 0011 | lr: 0.00540360 | Time(s): 40.8483\n",
      "03-Jun-21 12:02:19 | == train Loss <unroll: 8.2039 | vae : 35.2208 | kl : 0.0155>\n",
      "03-Jun-21 12:02:19 | == gmse <train: 6.8267 | val: 6.7381> \n",
      "03-Jun-21 12:03:00 | Epoch 0012 | lr: 0.00513342 | Time(s): 40.8461\n",
      "03-Jun-21 12:03:00 | == train Loss <unroll: 7.9739 | vae : 35.6731 | kl : 0.0150>\n",
      "03-Jun-21 12:03:00 | == gmse <train: 6.8378 | val: 6.6642> \n",
      "03-Jun-21 12:03:41 | Epoch 0013 | lr: 0.00487675 | Time(s): 40.8583\n",
      "03-Jun-21 12:03:41 | == train Loss <unroll: 7.8523 | vae : 35.9756 | kl : 0.0145>\n",
      "03-Jun-21 12:03:41 | == gmse <train: 6.8166 | val: 6.5820> \n",
      "03-Jun-21 12:04:22 | Epoch 0014 | lr: 0.00463291 | Time(s): 40.8345\n",
      "03-Jun-21 12:04:22 | == train Loss <unroll: 7.3191 | vae : 33.6679 | kl : 0.0141>\n",
      "03-Jun-21 12:04:22 | == gmse <train: 6.3053 | val: 5.3174> \n",
      "03-Jun-21 12:05:03 | Epoch 0015 | lr: 0.00440127 | Time(s): 40.8395\n",
      "03-Jun-21 12:05:03 | == train Loss <unroll: 6.9376 | vae : 32.1065 | kl : 0.0137>\n",
      "03-Jun-21 12:05:03 | == gmse <train: 5.9468 | val: 5.7052> \n",
      "03-Jun-21 12:05:43 | Epoch 0016 | lr: 0.00418120 | Time(s): 40.8321\n",
      "03-Jun-21 12:05:43 | == train Loss <unroll: 6.4328 | vae : 29.6860 | kl : 0.0133>\n",
      "03-Jun-21 12:05:43 | == gmse <train: 5.4408 | val: 4.6764> \n",
      "03-Jun-21 12:06:24 | Epoch 0017 | lr: 0.00397214 | Time(s): 40.8393\n",
      "03-Jun-21 12:06:24 | == train Loss <unroll: 5.9698 | vae : 27.4016 | kl : 0.0129>\n",
      "03-Jun-21 12:06:24 | == gmse <train: 4.9732 | val: 4.4460> \n",
      "03-Jun-21 12:07:05 | Epoch 0018 | lr: 0.00377354 | Time(s): 40.8222\n",
      "03-Jun-21 12:07:05 | == train Loss <unroll: 5.5616 | vae : 25.3595 | kl : 0.0125>\n",
      "03-Jun-21 12:07:05 | == gmse <train: 4.5607 | val: 3.6647> \n",
      "03-Jun-21 12:07:46 | Epoch 0019 | lr: 0.00358486 | Time(s): 40.8210\n",
      "03-Jun-21 12:07:46 | == train Loss <unroll: 5.1548 | vae : 23.2456 | kl : 0.0121>\n",
      "03-Jun-21 12:07:46 | == gmse <train: 4.1450 | val: 3.6884> \n",
      "03-Jun-21 12:08:26 | Epoch 0020 | lr: 0.00340562 | Time(s): 40.8061\n",
      "03-Jun-21 12:08:26 | == train Loss <unroll: 4.8009 | vae : 21.3930 | kl : 0.0117>\n",
      "03-Jun-21 12:08:26 | == gmse <train: 3.7844 | val: 3.3252> \n",
      "03-Jun-21 12:09:07 | Epoch 0021 | lr: 0.00323534 | Time(s): 40.8053\n",
      "03-Jun-21 12:09:07 | == train Loss <unroll: 4.5035 | vae : 19.8383 | kl : 0.0114>\n",
      "03-Jun-21 12:09:07 | == gmse <train: 3.4835 | val: 2.8227> \n",
      "03-Jun-21 12:09:48 | Epoch 0022 | lr: 0.00307357 | Time(s): 40.7992\n",
      "03-Jun-21 12:09:48 | == train Loss <unroll: 4.2076 | vae : 18.2466 | kl : 0.0111>\n",
      "03-Jun-21 12:09:48 | == gmse <train: 3.1816 | val: 2.4498> \n",
      "03-Jun-21 12:10:28 | Epoch 0023 | lr: 0.00291989 | Time(s): 40.7937\n",
      "03-Jun-21 12:10:28 | == train Loss <unroll: 3.9820 | vae : 17.0636 | kl : 0.0108>\n",
      "03-Jun-21 12:10:28 | == gmse <train: 2.9558 | val: 2.2236> \n",
      "03-Jun-21 12:11:09 | Epoch 0024 | lr: 0.00277390 | Time(s): 40.7794\n",
      "03-Jun-21 12:11:09 | == train Loss <unroll: 3.7693 | vae : 15.9328 | kl : 0.0105>\n",
      "03-Jun-21 12:11:09 | == gmse <train: 2.7425 | val: 2.3312> \n",
      "03-Jun-21 12:11:50 | Epoch 0025 | lr: 0.00263520 | Time(s): 40.7884\n",
      "03-Jun-21 12:11:50 | == train Loss <unroll: 3.5882 | vae : 14.9786 | kl : 0.0102>\n",
      "03-Jun-21 12:11:50 | == gmse <train: 2.5628 | val: 1.8957> \n",
      "03-Jun-21 12:12:31 | Epoch 0026 | lr: 0.00250344 | Time(s): 40.7925\n",
      "03-Jun-21 12:12:31 | == train Loss <unroll: 3.3686 | vae : 13.7502 | kl : 0.0099>\n",
      "03-Jun-21 12:12:31 | == gmse <train: 2.3390 | val: 1.7863> \n",
      "03-Jun-21 12:13:11 | Epoch 0027 | lr: 0.00237827 | Time(s): 40.7848\n",
      "03-Jun-21 12:13:11 | == train Loss <unroll: 3.2008 | vae : 12.8367 | kl : 0.0096>\n",
      "03-Jun-21 12:13:11 | == gmse <train: 2.1716 | val: 1.5314> \n",
      "03-Jun-21 12:13:52 | Epoch 0028 | lr: 0.00225936 | Time(s): 40.7712\n",
      "03-Jun-21 12:13:52 | == train Loss <unroll: 3.0794 | vae : 12.2112 | kl : 0.0093>\n",
      "03-Jun-21 12:13:52 | == gmse <train: 2.0549 | val: 1.5019> \n",
      "03-Jun-21 12:14:32 | Epoch 0029 | lr: 0.00214639 | Time(s): 40.7624\n",
      "03-Jun-21 12:14:32 | == train Loss <unroll: 2.9209 | vae : 11.3260 | kl : 0.0091>\n",
      "03-Jun-21 12:14:32 | == gmse <train: 1.8960 | val: 1.3157> \n",
      "03-Jun-21 12:15:13 | Epoch 0030 | lr: 0.00203907 | Time(s): 40.7498\n",
      "03-Jun-21 12:15:13 | == train Loss <unroll: 2.8138 | vae : 10.7682 | kl : 0.0088>\n",
      "03-Jun-21 12:15:13 | == gmse <train: 1.7937 | val: 1.3077> \n",
      "03-Jun-21 12:15:53 | Epoch 0031 | lr: 0.00193711 | Time(s): 40.7457\n",
      "03-Jun-21 12:15:53 | == train Loss <unroll: 2.6902 | vae : 10.0900 | kl : 0.0086>\n",
      "03-Jun-21 12:15:53 | == gmse <train: 1.6725 | val: 1.2133> \n",
      "03-Jun-21 12:16:34 | Epoch 0032 | lr: 0.00184026 | Time(s): 40.7463\n",
      "03-Jun-21 12:16:34 | == train Loss <unroll: 2.5781 | vae : 9.4776 | kl : 0.0084>\n",
      "03-Jun-21 12:16:34 | == gmse <train: 1.5635 | val: 1.1358> \n",
      "03-Jun-21 12:18:36 | Epoch 0035 | lr: 0.00157779 | Time(s): 40.7460\n",
      "03-Jun-21 12:18:36 | == train Loss <unroll: 2.3066 | vae : 8.0269 | kl : 0.0078>\n",
      "03-Jun-21 12:18:36 | == gmse <train: 1.3061 | val: 0.8959> \n",
      "03-Jun-21 12:19:17 | Epoch 0036 | lr: 0.00149890 | Time(s): 40.7383\n",
      "03-Jun-21 12:19:17 | == train Loss <unroll: 2.2301 | vae : 7.6259 | kl : 0.0076>\n",
      "03-Jun-21 12:19:17 | == gmse <train: 1.2353 | val: 0.9513> \n",
      "03-Jun-21 12:19:58 | Epoch 0037 | lr: 0.00142396 | Time(s): 40.7383\n",
      "03-Jun-21 12:19:58 | == train Loss <unroll: 2.1700 | vae : 7.3315 | kl : 0.0074>\n",
      "03-Jun-21 12:19:58 | == gmse <train: 1.1824 | val: 0.8482> \n",
      "03-Jun-21 12:20:39 | Epoch 0038 | lr: 0.00135276 | Time(s): 40.7481\n",
      "03-Jun-21 12:20:39 | == train Loss <unroll: 2.0998 | vae : 6.9632 | kl : 0.0072>\n",
      "03-Jun-21 12:20:39 | == gmse <train: 1.1181 | val: 0.8539> \n",
      "03-Jun-21 12:21:19 | Epoch 0039 | lr: 0.00128512 | Time(s): 40.7486\n",
      "03-Jun-21 12:21:19 | == train Loss <unroll: 2.0437 | vae : 6.6867 | kl : 0.0071>\n",
      "03-Jun-21 12:21:19 | == gmse <train: 1.0690 | val: 0.8018> \n",
      "03-Jun-21 12:22:00 | Epoch 0040 | lr: 0.00122087 | Time(s): 40.7463\n",
      "03-Jun-21 12:22:00 | == train Loss <unroll: 1.9797 | vae : 6.3513 | kl : 0.0069>\n",
      "03-Jun-21 12:22:00 | == gmse <train: 1.0109 | val: 0.7484> \n",
      "03-Jun-21 12:22:41 | Epoch 0041 | lr: 0.00115982 | Time(s): 40.7475\n",
      "03-Jun-21 12:22:41 | == train Loss <unroll: 1.9251 | vae : 6.0771 | kl : 0.0068>\n",
      "03-Jun-21 12:22:41 | == gmse <train: 0.9631 | val: 0.7219> \n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "03-Jun-21 12:23:21 | Epoch 0042 | lr: 0.00110183 | Time(s): 40.7423\n",
      "03-Jun-21 12:23:21 | == train Loss <unroll: 1.8833 | vae : 5.8874 | kl : 0.0066>\n",
      "03-Jun-21 12:23:21 | == gmse <train: 0.9290 | val: 0.7351> \n",
      "03-Jun-21 12:24:02 | Epoch 0043 | lr: 0.00104674 | Time(s): 40.7440\n",
      "03-Jun-21 12:24:02 | == train Loss <unroll: 1.8305 | vae : 5.6151 | kl : 0.0065>\n",
      "03-Jun-21 12:24:02 | == gmse <train: 0.8821 | val: 0.6594> \n",
      "03-Jun-21 12:24:43 | Epoch 0044 | lr: 0.00099440 | Time(s): 40.7363\n",
      "03-Jun-21 12:24:43 | == train Loss <unroll: 1.7904 | vae : 5.4278 | kl : 0.0064>\n",
      "03-Jun-21 12:24:43 | == gmse <train: 0.8490 | val: 0.6496> \n",
      "03-Jun-21 12:25:24 | Epoch 0045 | lr: 0.00094468 | Time(s): 40.7447\n",
      "03-Jun-21 12:25:24 | == train Loss <unroll: 1.7514 | vae : 5.2409 | kl : 0.0063>\n",
      "03-Jun-21 12:25:24 | == gmse <train: 0.8161 | val: 0.6713> \n",
      "03-Jun-21 12:26:04 | Epoch 0046 | lr: 0.00089745 | Time(s): 40.7434\n",
      "03-Jun-21 12:26:04 | == train Loss <unroll: 1.7093 | vae : 5.0235 | kl : 0.0061>\n",
      "03-Jun-21 12:26:04 | == gmse <train: 0.7788 | val: 0.6011> \n",
      "03-Jun-21 12:26:45 | Epoch 0047 | lr: 0.00085258 | Time(s): 40.7457\n",
      "03-Jun-21 12:26:45 | == train Loss <unroll: 1.6755 | vae : 4.8631 | kl : 0.0060>\n",
      "03-Jun-21 12:26:45 | == gmse <train: 0.7505 | val: 0.6289> \n",
      "03-Jun-21 12:27:26 | Epoch 0048 | lr: 0.00080995 | Time(s): 40.7436\n",
      "03-Jun-21 12:27:26 | == train Loss <unroll: 1.6452 | vae : 4.7275 | kl : 0.0059>\n",
      "03-Jun-21 12:27:26 | == gmse <train: 0.7262 | val: 0.5885> \n",
      "03-Jun-21 12:28:06 | Epoch 0049 | lr: 0.00076945 | Time(s): 40.7392\n",
      "03-Jun-21 12:28:06 | == train Loss <unroll: 1.6175 | vae : 4.6137 | kl : 0.0058>\n",
      "03-Jun-21 12:28:06 | == gmse <train: 0.7054 | val: 0.6120> \n",
      "03-Jun-21 12:28:47 | Epoch 0050 | lr: 0.00073098 | Time(s): 40.7389\n",
      "03-Jun-21 12:28:47 | == train Loss <unroll: 1.5910 | vae : 4.5104 | kl : 0.0057>\n",
      "03-Jun-21 12:28:47 | == gmse <train: 0.6862 | val: 0.5523> \n",
      "03-Jun-21 12:29:29 | Epoch 0051 | lr: 0.00069443 | Time(s): 40.7521\n",
      "03-Jun-21 12:29:29 | == train Loss <unroll: 1.5532 | vae : 4.3348 | kl : 0.0055>\n",
      "03-Jun-21 12:29:29 | == gmse <train: 0.6563 | val: 0.5553> \n",
      "03-Jun-21 12:30:09 | Epoch 0052 | lr: 0.00065971 | Time(s): 40.7542\n",
      "03-Jun-21 12:30:09 | == train Loss <unroll: 1.5161 | vae : 4.1853 | kl : 0.0054>\n",
      "03-Jun-21 12:30:09 | == gmse <train: 0.6305 | val: 0.5081> \n",
      "03-Jun-21 12:30:49 | Epoch 0053 | lr: 0.00062672 | Time(s): 40.7381\n",
      "03-Jun-21 12:30:49 | == train Loss <unroll: 1.4844 | vae : 4.0773 | kl : 0.0053>\n",
      "03-Jun-21 12:30:49 | == gmse <train: 0.6110 | val: 0.5168> \n",
      "03-Jun-21 12:31:30 | Epoch 0054 | lr: 0.00059539 | Time(s): 40.7366\n",
      "03-Jun-21 12:31:30 | == train Loss <unroll: 1.4599 | vae : 4.0149 | kl : 0.0052>\n",
      "03-Jun-21 12:31:30 | == gmse <train: 0.5983 | val: 0.5083> \n",
      "03-Jun-21 12:32:11 | Epoch 0055 | lr: 0.00056562 | Time(s): 40.7389\n",
      "03-Jun-21 12:32:11 | == train Loss <unroll: 1.4319 | vae : 3.9235 | kl : 0.0051>\n",
      "03-Jun-21 12:32:11 | == gmse <train: 0.5813 | val: 0.5042> \n",
      "03-Jun-21 12:32:52 | Epoch 0056 | lr: 0.00053734 | Time(s): 40.7400\n",
      "03-Jun-21 12:32:52 | == train Loss <unroll: 1.4020 | vae : 3.8149 | kl : 0.0049>\n",
      "03-Jun-21 12:32:52 | == gmse <train: 0.5619 | val: 0.4932> \n",
      "03-Jun-21 12:33:32 | Epoch 0057 | lr: 0.00051047 | Time(s): 40.7308\n",
      "03-Jun-21 12:33:32 | == train Loss <unroll: 1.3757 | vae : 3.7357 | kl : 0.0048>\n",
      "03-Jun-21 12:33:32 | == gmse <train: 0.5468 | val: 0.4749> \n",
      "03-Jun-21 12:34:12 | Epoch 0058 | lr: 0.00048495 | Time(s): 40.7247\n",
      "03-Jun-21 12:34:12 | == train Loss <unroll: 1.3525 | vae : 3.6904 | kl : 0.0047>\n",
      "03-Jun-21 12:34:12 | == gmse <train: 0.5366 | val: 0.4779> \n",
      "03-Jun-21 12:34:53 | Epoch 0059 | lr: 0.00046070 | Time(s): 40.7190\n",
      "03-Jun-21 12:34:53 | == train Loss <unroll: 1.3217 | vae : 3.5959 | kl : 0.0046>\n",
      "03-Jun-21 12:34:53 | == gmse <train: 0.5192 | val: 0.4481> \n",
      "03-Jun-21 12:35:33 | Epoch 0060 | lr: 0.00043766 | Time(s): 40.7160\n",
      "03-Jun-21 12:35:33 | == train Loss <unroll: 1.2881 | vae : 3.5099 | kl : 0.0045>\n",
      "03-Jun-21 12:35:33 | == gmse <train: 0.5031 | val: 0.4387> \n",
      "03-Jun-21 12:36:14 | Epoch 0061 | lr: 0.00041578 | Time(s): 40.7184\n",
      "03-Jun-21 12:36:14 | == train Loss <unroll: 1.2541 | vae : 3.4398 | kl : 0.0043>\n",
      "03-Jun-21 12:36:14 | == gmse <train: 0.4892 | val: 0.4373> \n",
      "03-Jun-21 12:36:54 | Epoch 0062 | lr: 0.00039499 | Time(s): 40.7122\n",
      "03-Jun-21 12:36:54 | == train Loss <unroll: 1.2190 | vae : 3.3692 | kl : 0.0042>\n",
      "03-Jun-21 12:36:54 | == gmse <train: 0.4752 | val: 0.4187> \n",
      "03-Jun-21 12:37:35 | Epoch 0063 | lr: 0.00037524 | Time(s): 40.7022\n",
      "03-Jun-21 12:37:35 | == train Loss <unroll: 1.1935 | vae : 3.3504 | kl : 0.0041>\n",
      "03-Jun-21 12:37:35 | == gmse <train: 0.4682 | val: 0.4147> \n",
      "03-Jun-21 12:38:15 | Epoch 0064 | lr: 0.00035648 | Time(s): 40.6960\n",
      "03-Jun-21 12:38:15 | == train Loss <unroll: 1.1690 | vae : 3.2784 | kl : 0.0040>\n",
      "03-Jun-21 12:38:15 | == gmse <train: 0.4540 | val: 0.4115> \n",
      "03-Jun-21 12:38:55 | Epoch 0065 | lr: 0.00033866 | Time(s): 40.6929\n",
      "03-Jun-21 12:38:55 | == train Loss <unroll: 1.1511 | vae : 3.2306 | kl : 0.0038>\n",
      "03-Jun-21 12:38:55 | == gmse <train: 0.4431 | val: 0.4064> \n",
      "03-Jun-21 12:39:36 | Epoch 0066 | lr: 0.00032172 | Time(s): 40.6900\n",
      "03-Jun-21 12:39:36 | == train Loss <unroll: 1.1352 | vae : 3.1820 | kl : 0.0037>\n",
      "03-Jun-21 12:39:36 | == gmse <train: 0.4321 | val: 0.3951> \n",
      "03-Jun-21 12:40:16 | Epoch 0067 | lr: 0.00030564 | Time(s): 40.6824\n",
      "03-Jun-21 12:40:16 | == train Loss <unroll: 1.1231 | vae : 3.1491 | kl : 0.0036>\n",
      "03-Jun-21 12:40:16 | == gmse <train: 0.4232 | val: 0.3918> \n",
      "03-Jun-21 12:40:56 | Epoch 0068 | lr: 0.00029035 | Time(s): 40.6752\n",
      "03-Jun-21 12:40:56 | == train Loss <unroll: 1.1092 | vae : 3.0890 | kl : 0.0034>\n",
      "03-Jun-21 12:40:56 | == gmse <train: 0.4108 | val: 0.3770> \n",
      "03-Jun-21 12:41:37 | Epoch 0069 | lr: 0.00027584 | Time(s): 40.6728\n",
      "03-Jun-21 12:41:37 | == train Loss <unroll: 1.1000 | vae : 3.0596 | kl : 0.0033>\n",
      "03-Jun-21 12:41:37 | == gmse <train: 0.4023 | val: 0.3721> \n",
      "03-Jun-21 12:42:17 | Epoch 0070 | lr: 0.00026205 | Time(s): 40.6665\n",
      "03-Jun-21 12:42:17 | == train Loss <unroll: 1.0884 | vae : 3.0077 | kl : 0.0032>\n",
      "03-Jun-21 12:42:17 | == gmse <train: 0.3910 | val: 0.3630> \n",
      "03-Jun-21 12:42:57 | Epoch 0071 | lr: 0.00024894 | Time(s): 40.6645\n",
      "03-Jun-21 12:42:57 | == train Loss <unroll: 1.0806 | vae : 2.9871 | kl : 0.0030>\n",
      "03-Jun-21 12:42:57 | == gmse <train: 0.3835 | val: 0.3534> \n",
      "03-Jun-21 12:43:38 | Epoch 0072 | lr: 0.00023650 | Time(s): 40.6585\n",
      "03-Jun-21 12:43:38 | == train Loss <unroll: 1.0717 | vae : 2.9574 | kl : 0.0029>\n",
      "03-Jun-21 12:43:38 | == gmse <train: 0.3749 | val: 0.3459> \n",
      "03-Jun-21 12:44:18 | Epoch 0073 | lr: 0.00022467 | Time(s): 40.6547\n",
      "03-Jun-21 12:44:18 | == train Loss <unroll: 1.0606 | vae : 2.9100 | kl : 0.0028>\n",
      "03-Jun-21 12:44:18 | == gmse <train: 0.3641 | val: 0.3391> \n",
      "03-Jun-21 12:44:58 | Epoch 0074 | lr: 0.00021344 | Time(s): 40.6497\n",
      "03-Jun-21 12:44:58 | == train Loss <unroll: 1.0516 | vae : 2.8766 | kl : 0.0027>\n",
      "03-Jun-21 12:44:58 | == gmse <train: 0.3550 | val: 0.3316> \n",
      "03-Jun-21 12:45:38 | Epoch 0075 | lr: 0.00020277 | Time(s): 40.6424\n",
      "03-Jun-21 12:45:38 | == train Loss <unroll: 1.0422 | vae : 2.8406 | kl : 0.0025>\n",
      "03-Jun-21 12:45:38 | == gmse <train: 0.3456 | val: 0.3315> \n",
      "03-Jun-21 12:46:19 | Epoch 0076 | lr: 0.00019263 | Time(s): 40.6382\n",
      "03-Jun-21 12:46:19 | == train Loss <unroll: 1.0355 | vae : 2.8281 | kl : 0.0024>\n",
      "03-Jun-21 12:46:19 | == gmse <train: 0.3389 | val: 0.3218> \n",
      "03-Jun-21 12:46:59 | Epoch 0077 | lr: 0.00018300 | Time(s): 40.6326\n",
      "03-Jun-21 12:46:59 | == train Loss <unroll: 1.0264 | vae : 2.7953 | kl : 0.0023>\n",
      "03-Jun-21 12:46:59 | == gmse <train: 0.3299 | val: 0.3096> \n",
      "03-Jun-21 12:47:39 | Epoch 0078 | lr: 0.00017385 | Time(s): 40.6298\n",
      "03-Jun-21 12:47:39 | == train Loss <unroll: 1.0183 | vae : 2.7707 | kl : 0.0022>\n",
      "03-Jun-21 12:47:39 | == gmse <train: 0.3218 | val: 0.3037> \n",
      "03-Jun-21 12:48:20 | Epoch 0079 | lr: 0.00016515 | Time(s): 40.6229\n",
      "03-Jun-21 12:48:20 | == train Loss <unroll: 1.0107 | vae : 2.7485 | kl : 0.0020>\n",
      "03-Jun-21 12:48:20 | == gmse <train: 0.3140 | val: 0.2982> \n",
      "03-Jun-21 12:49:00 | Epoch 0080 | lr: 0.00015690 | Time(s): 40.6179\n",
      "03-Jun-21 12:49:00 | == train Loss <unroll: 1.0030 | vae : 2.7266 | kl : 0.0019>\n",
      "03-Jun-21 12:49:00 | == gmse <train: 0.3062 | val: 0.2961> \n",
      "03-Jun-21 12:49:40 | Epoch 0081 | lr: 0.00014905 | Time(s): 40.6095\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "03-Jun-21 12:49:40 | == train Loss <unroll: 0.9957 | vae : 2.7075 | kl : 0.0018>\n",
      "03-Jun-21 12:49:40 | == gmse <train: 0.2986 | val: 0.2879> \n",
      "03-Jun-21 12:50:20 | Epoch 0082 | lr: 0.00014160 | Time(s): 40.6075\n",
      "03-Jun-21 12:50:20 | == train Loss <unroll: 0.9875 | vae : 2.6799 | kl : 0.0017>\n",
      "03-Jun-21 12:50:20 | == gmse <train: 0.2903 | val: 0.2789> \n",
      "03-Jun-21 12:51:01 | Epoch 0083 | lr: 0.00013452 | Time(s): 40.6054\n",
      "03-Jun-21 12:51:01 | == train Loss <unroll: 0.9803 | vae : 2.6609 | kl : 0.0016>\n",
      "03-Jun-21 12:51:01 | == gmse <train: 0.2828 | val: 0.2702> \n",
      "03-Jun-21 12:51:41 | Epoch 0084 | lr: 0.00012779 | Time(s): 40.6052\n",
      "03-Jun-21 12:51:41 | == train Loss <unroll: 0.9738 | vae : 2.6468 | kl : 0.0015>\n",
      "03-Jun-21 12:51:41 | == gmse <train: 0.2758 | val: 0.2641> \n",
      "03-Jun-21 12:52:22 | Epoch 0085 | lr: 0.00012140 | Time(s): 40.6047\n",
      "03-Jun-21 12:52:22 | == train Loss <unroll: 0.9666 | vae : 2.6292 | kl : 0.0014>\n",
      "03-Jun-21 12:52:22 | == gmse <train: 0.2685 | val: 0.2600> \n",
      "03-Jun-21 12:53:02 | Epoch 0086 | lr: 0.00011533 | Time(s): 40.6012\n",
      "03-Jun-21 12:53:02 | == train Loss <unroll: 0.9594 | vae : 2.6121 | kl : 0.0013>\n",
      "03-Jun-21 12:53:02 | == gmse <train: 0.2612 | val: 0.2490> \n",
      "03-Jun-21 12:53:43 | Epoch 0087 | lr: 0.00010957 | Time(s): 40.5997\n",
      "03-Jun-21 12:53:43 | == train Loss <unroll: 0.9523 | vae : 2.5911 | kl : 0.0012>\n",
      "03-Jun-21 12:53:43 | == gmse <train: 0.2536 | val: 0.2442> \n",
      "03-Jun-21 12:54:23 | Epoch 0088 | lr: 0.00010409 | Time(s): 40.5969\n",
      "03-Jun-21 12:54:23 | == train Loss <unroll: 0.9459 | vae : 2.5767 | kl : 0.0011>\n",
      "03-Jun-21 12:54:23 | == gmse <train: 0.2467 | val: 0.2370> \n",
      "03-Jun-21 12:55:03 | Epoch 0089 | lr: 0.00009888 | Time(s): 40.5957\n",
      "03-Jun-21 12:55:03 | == train Loss <unroll: 0.9392 | vae : 2.5643 | kl : 0.0010>\n",
      "03-Jun-21 12:55:03 | == gmse <train: 0.2399 | val: 0.2307> \n",
      "03-Jun-21 12:55:44 | Epoch 0090 | lr: 0.00009394 | Time(s): 40.5930\n",
      "03-Jun-21 12:55:44 | == train Loss <unroll: 0.9327 | vae : 2.5520 | kl : 0.0009>\n",
      "03-Jun-21 12:55:44 | == gmse <train: 0.2330 | val: 0.2237> \n",
      "03-Jun-21 12:56:24 | Epoch 0091 | lr: 0.00008924 | Time(s): 40.5853\n",
      "03-Jun-21 12:56:24 | == train Loss <unroll: 0.9264 | vae : 2.5449 | kl : 0.0008>\n",
      "03-Jun-21 12:56:24 | == gmse <train: 0.2266 | val: 0.2180> \n",
      "03-Jun-21 12:57:04 | Epoch 0092 | lr: 0.00008478 | Time(s): 40.5797\n",
      "03-Jun-21 12:57:04 | == train Loss <unroll: 0.9196 | vae : 2.5277 | kl : 0.0007>\n",
      "03-Jun-21 12:57:04 | == gmse <train: 0.2194 | val: 0.2098> \n",
      "03-Jun-21 12:57:44 | Epoch 0093 | lr: 0.00008054 | Time(s): 40.5769\n",
      "03-Jun-21 12:57:44 | == train Loss <unroll: 0.9133 | vae : 2.5178 | kl : 0.0007>\n",
      "03-Jun-21 12:57:44 | == gmse <train: 0.2128 | val: 0.2050> \n",
      "03-Jun-21 12:58:25 | Epoch 0094 | lr: 0.00007651 | Time(s): 40.5768\n",
      "03-Jun-21 12:58:25 | == train Loss <unroll: 0.9065 | vae : 2.5001 | kl : 0.0006>\n",
      "03-Jun-21 12:58:25 | == gmse <train: 0.2055 | val: 0.1995> \n",
      "03-Jun-21 12:59:05 | Epoch 0095 | lr: 0.00007269 | Time(s): 40.5772\n",
      "03-Jun-21 12:59:05 | == train Loss <unroll: 0.9006 | vae : 2.4959 | kl : 0.0005>\n",
      "03-Jun-21 12:59:05 | == gmse <train: 0.1993 | val: 0.1920> \n",
      "03-Jun-21 12:59:46 | Epoch 0096 | lr: 0.00006905 | Time(s): 40.5755\n",
      "03-Jun-21 12:59:46 | == train Loss <unroll: 0.8938 | vae : 2.4808 | kl : 0.0005>\n",
      "03-Jun-21 12:59:46 | == gmse <train: 0.1924 | val: 0.1873> \n",
      "03-Jun-21 13:00:26 | Epoch 0097 | lr: 0.00006560 | Time(s): 40.5721\n",
      "03-Jun-21 13:00:26 | == train Loss <unroll: 0.8878 | vae : 2.4706 | kl : 0.0004>\n",
      "03-Jun-21 13:00:26 | == gmse <train: 0.1858 | val: 0.1813> \n",
      "03-Jun-21 13:01:06 | Epoch 0098 | lr: 0.00006232 | Time(s): 40.5681\n",
      "03-Jun-21 13:01:06 | == train Loss <unroll: 0.8816 | vae : 2.4624 | kl : 0.0004>\n",
      "03-Jun-21 13:01:06 | == gmse <train: 0.1795 | val: 0.1736> \n",
      "03-Jun-21 13:01:46 | Epoch 0099 | lr: 0.00005921 | Time(s): 40.5645\n",
      "03-Jun-21 13:01:46 | == train Loss <unroll: 0.8756 | vae : 2.4524 | kl : 0.0003>\n",
      "03-Jun-21 13:01:46 | == gmse <train: 0.1731 | val: 0.1674> \n",
      "03-Jun-21 13:02:27 | Epoch 0100 | lr: 0.00005625 | Time(s): 40.5650\n",
      "03-Jun-21 13:02:27 | == train Loss <unroll: 0.8694 | vae : 2.4429 | kl : 0.0003>\n",
      "03-Jun-21 13:02:27 | == gmse <train: 0.1668 | val: 0.1610> \n",
      "03-Jun-21 13:03:08 | Epoch 0101 | lr: 0.00005343 | Time(s): 40.5675\n",
      "03-Jun-21 13:03:08 | == train Loss <unroll: 0.8639 | vae : 2.4368 | kl : 0.0003>\n",
      "03-Jun-21 13:03:08 | == gmse <train: 0.1609 | val: 0.1550> \n",
      "03-Jun-21 13:03:48 | Epoch 0102 | lr: 0.00005076 | Time(s): 40.5633\n",
      "03-Jun-21 13:03:48 | == train Loss <unroll: 0.8580 | vae : 2.4263 | kl : 0.0002>\n",
      "03-Jun-21 13:03:48 | == gmse <train: 0.1548 | val: 0.1514> \n",
      "03-Jun-21 13:04:28 | Epoch 0103 | lr: 0.00004822 | Time(s): 40.5623\n",
      "03-Jun-21 13:04:28 | == train Loss <unroll: 0.8527 | vae : 2.4217 | kl : 0.0002>\n",
      "03-Jun-21 13:04:28 | == gmse <train: 0.1493 | val: 0.1438> \n",
      "03-Jun-21 13:05:09 | Epoch 0104 | lr: 0.00004581 | Time(s): 40.5620\n",
      "03-Jun-21 13:05:09 | == train Loss <unroll: 0.8472 | vae : 2.4120 | kl : 0.0002>\n",
      "03-Jun-21 13:05:09 | == gmse <train: 0.1436 | val: 0.1395> \n",
      "03-Jun-21 13:05:49 | Epoch 0105 | lr: 0.00004352 | Time(s): 40.5620\n",
      "03-Jun-21 13:05:49 | == train Loss <unroll: 0.8421 | vae : 2.4045 | kl : 0.0001>\n",
      "03-Jun-21 13:05:49 | == gmse <train: 0.1383 | val: 0.1338> \n",
      "03-Jun-21 13:06:30 | Epoch 0106 | lr: 0.00004135 | Time(s): 40.5614\n",
      "03-Jun-21 13:06:30 | == train Loss <unroll: 0.8372 | vae : 2.4007 | kl : 0.0001>\n",
      "03-Jun-21 13:06:30 | == gmse <train: 0.1334 | val: 0.1281> \n",
      "03-Jun-21 13:07:10 | Epoch 0107 | lr: 0.00003928 | Time(s): 40.5574\n",
      "03-Jun-21 13:07:10 | == train Loss <unroll: 0.8325 | vae : 2.3921 | kl : 0.0001>\n",
      "03-Jun-21 13:07:10 | == gmse <train: 0.1284 | val: 0.1252> \n",
      "03-Jun-21 13:07:51 | Epoch 0108 | lr: 0.00003731 | Time(s): 40.5571\n",
      "03-Jun-21 13:07:51 | == train Loss <unroll: 0.8280 | vae : 2.3858 | kl : 0.0001>\n",
      "03-Jun-21 13:07:51 | == gmse <train: 0.1239 | val: 0.1195> \n",
      "03-Jun-21 13:08:31 | Epoch 0109 | lr: 0.00003545 | Time(s): 40.5577\n",
      "03-Jun-21 13:08:31 | == train Loss <unroll: 0.8237 | vae : 2.3792 | kl : 0.0001>\n",
      "03-Jun-21 13:08:31 | == gmse <train: 0.1195 | val: 0.1155> \n",
      "03-Jun-21 13:09:12 | Epoch 0110 | lr: 0.00003368 | Time(s): 40.5571\n",
      "03-Jun-21 13:09:12 | == train Loss <unroll: 0.8197 | vae : 2.3739 | kl : 0.0001>\n",
      "03-Jun-21 13:09:12 | == gmse <train: 0.1155 | val: 0.1119> \n",
      "03-Jun-21 13:09:52 | Epoch 0111 | lr: 0.00003199 | Time(s): 40.5521\n",
      "03-Jun-21 13:09:52 | == train Loss <unroll: 0.8160 | vae : 2.3676 | kl : 0.0001>\n",
      "03-Jun-21 13:09:52 | == gmse <train: 0.1116 | val: 0.1090> \n",
      "03-Jun-21 13:10:32 | Epoch 0112 | lr: 0.00003039 | Time(s): 40.5493\n",
      "03-Jun-21 13:10:32 | == train Loss <unroll: 0.8125 | vae : 2.3660 | kl : 0.0000>\n",
      "03-Jun-21 13:10:32 | == gmse <train: 0.1082 | val: 0.1041> \n",
      "03-Jun-21 13:11:12 | Epoch 0113 | lr: 0.00002887 | Time(s): 40.5453\n",
      "03-Jun-21 13:11:12 | == train Loss <unroll: 0.8092 | vae : 2.3589 | kl : 0.0000>\n",
      "03-Jun-21 13:11:12 | == gmse <train: 0.1048 | val: 0.1019> \n",
      "03-Jun-21 13:11:53 | Epoch 0114 | lr: 0.00002743 | Time(s): 40.5446\n",
      "03-Jun-21 13:11:53 | == train Loss <unroll: 0.8061 | vae : 2.3528 | kl : 0.0000>\n",
      "03-Jun-21 13:11:53 | == gmse <train: 0.1017 | val: 0.0977> \n",
      "03-Jun-21 13:12:33 | Epoch 0115 | lr: 0.00002606 | Time(s): 40.5418\n",
      "03-Jun-21 13:12:33 | == train Loss <unroll: 0.8033 | vae : 2.3481 | kl : 0.0000>\n",
      "03-Jun-21 13:12:33 | == gmse <train: 0.0988 | val: 0.0962> \n",
      "03-Jun-21 13:13:13 | Epoch 0116 | lr: 0.00002475 | Time(s): 40.5406\n",
      "03-Jun-21 13:13:13 | == train Loss <unroll: 0.8005 | vae : 2.3437 | kl : 0.0000>\n",
      "03-Jun-21 13:13:13 | == gmse <train: 0.0962 | val: 0.0936> \n",
      "03-Jun-21 13:13:54 | Epoch 0117 | lr: 0.00002352 | Time(s): 40.5397\n",
      "03-Jun-21 13:13:54 | == train Loss <unroll: 0.7983 | vae : 2.3406 | kl : 0.0000>\n",
      "03-Jun-21 13:13:54 | == gmse <train: 0.0938 | val: 0.0914> \n",
      "03-Jun-21 13:14:33 | Epoch 0118 | lr: 0.00002234 | Time(s): 40.5342\n",
      "03-Jun-21 13:14:33 | == train Loss <unroll: 0.7959 | vae : 2.3351 | kl : 0.0000>\n",
      "03-Jun-21 13:14:33 | == gmse <train: 0.0916 | val: 0.0894> \n",
      "03-Jun-21 13:15:14 | Epoch 0119 | lr: 0.00002122 | Time(s): 40.5341\n",
      "03-Jun-21 13:15:14 | == train Loss <unroll: 0.7938 | vae : 2.3315 | kl : 0.0000>\n",
      "03-Jun-21 13:15:14 | == gmse <train: 0.0896 | val: 0.0868> \n",
      "03-Jun-21 13:15:55 | Epoch 0120 | lr: 0.00002016 | Time(s): 40.5339\n",
      "03-Jun-21 13:15:55 | == train Loss <unroll: 0.7919 | vae : 2.3279 | kl : 0.0000>\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "03-Jun-21 13:15:55 | == gmse <train: 0.0878 | val: 0.0845> \n",
      "03-Jun-21 13:16:35 | Epoch 0121 | lr: 0.00001915 | Time(s): 40.5334\n",
      "03-Jun-21 13:16:35 | == train Loss <unroll: 0.7903 | vae : 2.3248 | kl : 0.0000>\n",
      "03-Jun-21 13:16:35 | == gmse <train: 0.0861 | val: 0.0819> \n",
      "03-Jun-21 13:17:15 | Epoch 0122 | lr: 0.00001820 | Time(s): 40.5313\n",
      "03-Jun-21 13:17:15 | == train Loss <unroll: 0.7887 | vae : 2.3191 | kl : 0.0000>\n",
      "03-Jun-21 13:17:15 | == gmse <train: 0.0846 | val: 0.0818> \n",
      "03-Jun-21 13:17:56 | Epoch 0123 | lr: 0.00001729 | Time(s): 40.5315\n",
      "03-Jun-21 13:17:56 | == train Loss <unroll: 0.7872 | vae : 2.3158 | kl : 0.0000>\n",
      "03-Jun-21 13:17:56 | == gmse <train: 0.0832 | val: 0.0798> \n",
      "03-Jun-21 13:18:36 | Epoch 0124 | lr: 0.00001642 | Time(s): 40.5300\n",
      "03-Jun-21 13:18:36 | == train Loss <unroll: 0.7859 | vae : 2.3133 | kl : 0.0000>\n",
      "03-Jun-21 13:18:36 | == gmse <train: 0.0819 | val: 0.0784> \n",
      "03-Jun-21 13:19:17 | Epoch 0125 | lr: 0.00001560 | Time(s): 40.5289\n",
      "03-Jun-21 13:19:17 | == train Loss <unroll: 0.7847 | vae : 2.3095 | kl : 0.0000>\n",
      "03-Jun-21 13:19:17 | == gmse <train: 0.0808 | val: 0.0772> \n",
      "03-Jun-21 13:19:57 | Epoch 0126 | lr: 0.00001482 | Time(s): 40.5302\n",
      "03-Jun-21 13:19:57 | == train Loss <unroll: 0.7836 | vae : 2.3068 | kl : 0.0000>\n",
      "03-Jun-21 13:19:57 | == gmse <train: 0.0798 | val: 0.0758> \n",
      "03-Jun-21 13:20:38 | Epoch 0127 | lr: 0.00001408 | Time(s): 40.5293\n",
      "03-Jun-21 13:20:38 | == train Loss <unroll: 0.7826 | vae : 2.3022 | kl : 0.0000>\n",
      "03-Jun-21 13:20:38 | == gmse <train: 0.0788 | val: 0.0750> \n",
      "03-Jun-21 13:21:18 | Epoch 0128 | lr: 0.00001338 | Time(s): 40.5280\n",
      "03-Jun-21 13:21:18 | == train Loss <unroll: 0.7817 | vae : 2.3003 | kl : 0.0000>\n",
      "03-Jun-21 13:21:18 | == gmse <train: 0.0779 | val: 0.0748> \n",
      "03-Jun-21 13:21:59 | Epoch 0129 | lr: 0.00001271 | Time(s): 40.5290\n",
      "03-Jun-21 13:21:59 | == train Loss <unroll: 0.7809 | vae : 2.2982 | kl : 0.0000>\n",
      "03-Jun-21 13:21:59 | == gmse <train: 0.0772 | val: 0.0737> \n",
      "03-Jun-21 13:22:39 | Epoch 0130 | lr: 0.00001207 | Time(s): 40.5270\n",
      "03-Jun-21 13:22:39 | == train Loss <unroll: 0.7801 | vae : 2.2937 | kl : 0.0000>\n",
      "03-Jun-21 13:22:39 | == gmse <train: 0.0765 | val: 0.0725> \n",
      "03-Jun-21 13:23:19 | Epoch 0131 | lr: 0.00001147 | Time(s): 40.5245\n",
      "03-Jun-21 13:23:19 | == train Loss <unroll: 0.7794 | vae : 2.2923 | kl : 0.0000>\n",
      "03-Jun-21 13:23:19 | == gmse <train: 0.0758 | val: 0.0721> \n",
      "03-Jun-21 13:24:00 | Epoch 0132 | lr: 0.00001090 | Time(s): 40.5233\n",
      "03-Jun-21 13:24:00 | == train Loss <unroll: 0.7787 | vae : 2.2902 | kl : 0.0000>\n",
      "03-Jun-21 13:24:00 | == gmse <train: 0.0753 | val: 0.0713> \n",
      "03-Jun-21 13:24:40 | Epoch 0133 | lr: 0.00001035 | Time(s): 40.5219\n",
      "03-Jun-21 13:24:40 | == train Loss <unroll: 0.7782 | vae : 2.2872 | kl : 0.0000>\n",
      "03-Jun-21 13:24:40 | == gmse <train: 0.0748 | val: 0.0700> \n",
      "03-Jun-21 13:25:20 | Epoch 0134 | lr: 0.00000983 | Time(s): 40.5206\n",
      "03-Jun-21 13:25:20 | == train Loss <unroll: 0.7777 | vae : 2.2850 | kl : 0.0000>\n",
      "03-Jun-21 13:25:20 | == gmse <train: 0.0743 | val: 0.0701> \n",
      "03-Jun-21 13:26:01 | Epoch 0135 | lr: 0.00000934 | Time(s): 40.5210\n",
      "03-Jun-21 13:26:01 | == train Loss <unroll: 0.7772 | vae : 2.2820 | kl : 0.0000>\n",
      "03-Jun-21 13:26:01 | == gmse <train: 0.0739 | val: 0.0697> \n",
      "03-Jun-21 13:26:41 | Epoch 0136 | lr: 0.00000887 | Time(s): 40.5201\n",
      "03-Jun-21 13:26:41 | == train Loss <unroll: 0.7768 | vae : 2.2811 | kl : 0.0000>\n",
      "03-Jun-21 13:26:41 | == gmse <train: 0.0735 | val: 0.0693> \n",
      "03-Jun-21 13:27:22 | Epoch 0137 | lr: 0.00000843 | Time(s): 40.5186\n",
      "03-Jun-21 13:27:22 | == train Loss <unroll: 0.7764 | vae : 2.2773 | kl : 0.0000>\n",
      "03-Jun-21 13:27:22 | == gmse <train: 0.0731 | val: 0.0692> \n",
      "03-Jun-21 13:28:02 | Epoch 0138 | lr: 0.00000801 | Time(s): 40.5175\n",
      "03-Jun-21 13:28:02 | == train Loss <unroll: 0.7760 | vae : 2.2773 | kl : 0.0000>\n",
      "03-Jun-21 13:28:02 | == gmse <train: 0.0728 | val: 0.0684> \n",
      "03-Jun-21 13:28:42 | Epoch 0139 | lr: 0.00000761 | Time(s): 40.5169\n",
      "03-Jun-21 13:28:42 | == train Loss <unroll: 0.7757 | vae : 2.2751 | kl : 0.0000>\n",
      "03-Jun-21 13:28:42 | == gmse <train: 0.0725 | val: 0.0687> \n",
      "03-Jun-21 13:29:23 | Epoch 0140 | lr: 0.00000723 | Time(s): 40.5151\n",
      "03-Jun-21 13:29:23 | == train Loss <unroll: 0.7754 | vae : 2.2731 | kl : 0.0000>\n",
      "03-Jun-21 13:29:23 | == gmse <train: 0.0723 | val: 0.0676> \n",
      "03-Jun-21 13:30:03 | Epoch 0141 | lr: 0.00000687 | Time(s): 40.5147\n",
      "03-Jun-21 13:30:03 | == train Loss <unroll: 0.7752 | vae : 2.2719 | kl : 0.0000>\n",
      "03-Jun-21 13:30:03 | == gmse <train: 0.0721 | val: 0.0674> \n",
      "03-Jun-21 13:30:43 | Epoch 0142 | lr: 0.00000652 | Time(s): 40.5138\n",
      "03-Jun-21 13:30:43 | == train Loss <unroll: 0.7749 | vae : 2.2693 | kl : 0.0000>\n",
      "03-Jun-21 13:30:43 | == gmse <train: 0.0718 | val: 0.0676> \n",
      "03-Jun-21 13:31:24 | Epoch 0143 | lr: 0.00000620 | Time(s): 40.5142\n",
      "03-Jun-21 13:31:24 | == train Loss <unroll: 0.7747 | vae : 2.2675 | kl : 0.0000>\n",
      "03-Jun-21 13:31:24 | == gmse <train: 0.0717 | val: 0.0670> \n",
      "03-Jun-21 13:32:04 | Epoch 0144 | lr: 0.00000589 | Time(s): 40.5127\n",
      "03-Jun-21 13:32:04 | == train Loss <unroll: 0.7745 | vae : 2.2663 | kl : 0.0000>\n",
      "03-Jun-21 13:32:04 | == gmse <train: 0.0715 | val: 0.0668> \n",
      "03-Jun-21 13:32:45 | Epoch 0145 | lr: 0.00000559 | Time(s): 40.5126\n",
      "03-Jun-21 13:32:45 | == train Loss <unroll: 0.7743 | vae : 2.2654 | kl : 0.0000>\n",
      "03-Jun-21 13:32:45 | == gmse <train: 0.0713 | val: 0.0661> \n",
      "03-Jun-21 13:33:25 | Epoch 0146 | lr: 0.00000531 | Time(s): 40.5133\n",
      "03-Jun-21 13:33:25 | == train Loss <unroll: 0.7741 | vae : 2.2626 | kl : 0.0000>\n",
      "03-Jun-21 13:33:25 | == gmse <train: 0.0712 | val: 0.0663> \n",
      "03-Jun-21 13:34:06 | Epoch 0147 | lr: 0.00000505 | Time(s): 40.5110\n",
      "03-Jun-21 13:34:06 | == train Loss <unroll: 0.7739 | vae : 2.2618 | kl : 0.0000>\n",
      "03-Jun-21 13:34:06 | == gmse <train: 0.0710 | val: 0.0659> \n",
      "03-Jun-21 13:34:46 | Epoch 0148 | lr: 0.00000480 | Time(s): 40.5085\n",
      "03-Jun-21 13:34:46 | == train Loss <unroll: 0.7738 | vae : 2.2606 | kl : 0.0000>\n",
      "03-Jun-21 13:34:46 | == gmse <train: 0.0709 | val: 0.0658> \n",
      "03-Jun-21 13:35:26 | Epoch 0149 | lr: 0.00000456 | Time(s): 40.5059\n",
      "03-Jun-21 13:35:26 | == train Loss <unroll: 0.7736 | vae : 2.2595 | kl : 0.0000>\n",
      "03-Jun-21 13:35:26 | == gmse <train: 0.0708 | val: 0.0659> \n",
      "03-Jun-21 13:36:06 | Epoch 0150 | lr: 0.00000433 | Time(s): 40.5045\n",
      "03-Jun-21 13:36:06 | == train Loss <unroll: 0.7735 | vae : 2.2583 | kl : 0.0000>\n",
      "03-Jun-21 13:36:06 | == gmse <train: 0.0707 | val: 0.0657> \n",
      "03-Jun-21 13:36:46 | Epoch 0151 | lr: 0.00000411 | Time(s): 40.5021\n",
      "03-Jun-21 13:36:46 | == train Loss <unroll: 0.7734 | vae : 2.2568 | kl : 0.0000>\n",
      "03-Jun-21 13:36:46 | == gmse <train: 0.0706 | val: 0.0657> \n",
      "03-Jun-21 13:37:26 | Epoch 0152 | lr: 0.00000391 | Time(s): 40.4991\n",
      "03-Jun-21 13:37:26 | == train Loss <unroll: 0.7733 | vae : 2.2559 | kl : 0.0000>\n",
      "03-Jun-21 13:37:26 | == gmse <train: 0.0705 | val: 0.0655> \n",
      "03-Jun-21 13:38:07 | Epoch 0153 | lr: 0.00000371 | Time(s): 40.4970\n",
      "03-Jun-21 13:38:07 | == train Loss <unroll: 0.7732 | vae : 2.2548 | kl : 0.0000>\n",
      "03-Jun-21 13:38:07 | == gmse <train: 0.0705 | val: 0.0652> \n",
      "03-Jun-21 13:38:47 | Epoch 0154 | lr: 0.00000352 | Time(s): 40.4953\n",
      "03-Jun-21 13:38:47 | == train Loss <unroll: 0.7731 | vae : 2.2543 | kl : 0.0000>\n",
      "03-Jun-21 13:38:47 | == gmse <train: 0.0704 | val: 0.0646> \n",
      "03-Jun-21 13:39:27 | Epoch 0155 | lr: 0.00000335 | Time(s): 40.4936\n",
      "03-Jun-21 13:39:27 | == train Loss <unroll: 0.7730 | vae : 2.2526 | kl : 0.0000>\n",
      "03-Jun-21 13:39:27 | == gmse <train: 0.0703 | val: 0.0650> \n",
      "03-Jun-21 13:40:07 | Epoch 0156 | lr: 0.00000318 | Time(s): 40.4927\n",
      "03-Jun-21 13:40:07 | == train Loss <unroll: 0.7729 | vae : 2.2518 | kl : 0.0000>\n",
      "03-Jun-21 13:40:07 | == gmse <train: 0.0702 | val: 0.0649> \n",
      "03-Jun-21 13:40:48 | Epoch 0157 | lr: 0.00000302 | Time(s): 40.4915\n",
      "03-Jun-21 13:40:48 | == train Loss <unroll: 0.7729 | vae : 2.2510 | kl : 0.0000>\n",
      "03-Jun-21 13:40:48 | == gmse <train: 0.0702 | val: 0.0647> \n",
      "03-Jun-21 13:41:28 | Epoch 0158 | lr: 0.00000287 | Time(s): 40.4909\n",
      "03-Jun-21 13:41:28 | == train Loss <unroll: 0.7728 | vae : 2.2499 | kl : 0.0000>\n",
      "03-Jun-21 13:41:28 | == gmse <train: 0.0701 | val: 0.0648> \n",
      "03-Jun-21 13:42:09 | Epoch 0159 | lr: 0.00000273 | Time(s): 40.4906\n",
      "03-Jun-21 13:42:09 | == train Loss <unroll: 0.7727 | vae : 2.2492 | kl : 0.0000>\n",
      "03-Jun-21 13:42:09 | == gmse <train: 0.0701 | val: 0.0648> \n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "03-Jun-21 13:42:49 | Epoch 0160 | lr: 0.00000259 | Time(s): 40.4885\n",
      "03-Jun-21 13:42:49 | == train Loss <unroll: 0.7727 | vae : 2.2484 | kl : 0.0000>\n",
      "03-Jun-21 13:42:49 | == gmse <train: 0.0700 | val: 0.0647> \n",
      "03-Jun-21 13:43:29 | Epoch 0161 | lr: 0.00000246 | Time(s): 40.4863\n",
      "03-Jun-21 13:43:29 | == train Loss <unroll: 0.7726 | vae : 2.2477 | kl : 0.0000>\n",
      "03-Jun-21 13:43:29 | == gmse <train: 0.0700 | val: 0.0648> \n",
      "03-Jun-21 13:44:09 | Epoch 0162 | lr: 0.00000234 | Time(s): 40.4862\n",
      "03-Jun-21 13:44:09 | == train Loss <unroll: 0.7726 | vae : 2.2468 | kl : 0.0000>\n",
      "03-Jun-21 13:44:09 | == gmse <train: 0.0700 | val: 0.0645> \n",
      "03-Jun-21 13:44:50 | Epoch 0163 | lr: 0.00000222 | Time(s): 40.4866\n",
      "03-Jun-21 13:44:50 | == train Loss <unroll: 0.7725 | vae : 2.2462 | kl : 0.0000>\n",
      "03-Jun-21 13:44:50 | == gmse <train: 0.0699 | val: 0.0645> \n",
      "03-Jun-21 13:45:30 | Epoch 0164 | lr: 0.00000211 | Time(s): 40.4833\n",
      "03-Jun-21 13:45:30 | == train Loss <unroll: 0.7724 | vae : 2.2456 | kl : 0.0000>\n",
      "03-Jun-21 13:45:30 | == gmse <train: 0.0699 | val: 0.0643> \n",
      "03-Jun-21 13:46:10 | Epoch 0165 | lr: 0.00000201 | Time(s): 40.4822\n",
      "03-Jun-21 13:46:10 | == train Loss <unroll: 0.7724 | vae : 2.2449 | kl : 0.0000>\n",
      "03-Jun-21 13:46:10 | == gmse <train: 0.0698 | val: 0.0646> \n",
      "03-Jun-21 13:46:51 | Epoch 0166 | lr: 0.00000190 | Time(s): 40.4830\n",
      "03-Jun-21 13:46:51 | == train Loss <unroll: 0.7724 | vae : 2.2444 | kl : 0.0000>\n",
      "03-Jun-21 13:46:51 | == gmse <train: 0.0698 | val: 0.0642> \n",
      "03-Jun-21 13:47:31 | Epoch 0167 | lr: 0.00000181 | Time(s): 40.4833\n",
      "03-Jun-21 13:47:31 | == train Loss <unroll: 0.7723 | vae : 2.2436 | kl : 0.0000>\n",
      "03-Jun-21 13:47:31 | == gmse <train: 0.0698 | val: 0.0645> \n",
      "03-Jun-21 13:48:12 | Epoch 0168 | lr: 0.00000172 | Time(s): 40.4821\n",
      "03-Jun-21 13:48:12 | == train Loss <unroll: 0.7723 | vae : 2.2435 | kl : 0.0000>\n",
      "03-Jun-21 13:48:12 | == gmse <train: 0.0698 | val: 0.0644> \n",
      "03-Jun-21 13:48:52 | Epoch 0169 | lr: 0.00000163 | Time(s): 40.4804\n",
      "03-Jun-21 13:48:52 | == train Loss <unroll: 0.7723 | vae : 2.2424 | kl : 0.0000>\n",
      "03-Jun-21 13:48:52 | == gmse <train: 0.0697 | val: 0.0642> \n",
      "03-Jun-21 13:49:32 | Epoch 0170 | lr: 0.00000155 | Time(s): 40.4796\n",
      "03-Jun-21 13:49:32 | == train Loss <unroll: 0.7722 | vae : 2.2421 | kl : 0.0000>\n",
      "03-Jun-21 13:49:32 | == gmse <train: 0.0697 | val: 0.0639> \n",
      "03-Jun-21 13:50:13 | Epoch 0171 | lr: 0.00000147 | Time(s): 40.4791\n",
      "03-Jun-21 13:50:13 | == train Loss <unroll: 0.7721 | vae : 2.2418 | kl : 0.0000>\n",
      "03-Jun-21 13:50:13 | == gmse <train: 0.0696 | val: 0.0616> \n",
      "03-Jun-21 13:50:53 | Epoch 0172 | lr: 0.00000140 | Time(s): 40.4778\n",
      "03-Jun-21 13:50:53 | == train Loss <unroll: 0.7714 | vae : 2.2413 | kl : 0.0000>\n",
      "03-Jun-21 13:50:53 | == gmse <train: 0.0689 | val: 0.0615> \n",
      "03-Jun-21 13:51:33 | Epoch 0173 | lr: 0.00000133 | Time(s): 40.4760\n",
      "03-Jun-21 13:51:33 | == train Loss <unroll: 0.7713 | vae : 2.2408 | kl : 0.0000>\n",
      "03-Jun-21 13:51:33 | == gmse <train: 0.0688 | val: 0.0619> \n",
      "03-Jun-21 13:52:13 | Epoch 0174 | lr: 0.00000126 | Time(s): 40.4756\n",
      "03-Jun-21 13:52:13 | == train Loss <unroll: 0.7712 | vae : 2.2403 | kl : 0.0000>\n",
      "03-Jun-21 13:52:13 | == gmse <train: 0.0687 | val: 0.0617> \n",
      "03-Jun-21 13:52:54 | Epoch 0175 | lr: 0.00000120 | Time(s): 40.4750\n",
      "03-Jun-21 13:52:54 | == train Loss <unroll: 0.7712 | vae : 2.2400 | kl : 0.0000>\n",
      "03-Jun-21 13:52:54 | == gmse <train: 0.0687 | val: 0.0616> \n",
      "03-Jun-21 13:53:34 | Epoch 0176 | lr: 0.00000114 | Time(s): 40.4746\n",
      "03-Jun-21 13:53:34 | == train Loss <unroll: 0.7711 | vae : 2.2397 | kl : 0.0000>\n",
      "03-Jun-21 13:53:34 | == gmse <train: 0.0686 | val: 0.0617> \n",
      "03-Jun-21 13:54:14 | Epoch 0177 | lr: 0.00000108 | Time(s): 40.4726\n",
      "03-Jun-21 13:54:14 | == train Loss <unroll: 0.7711 | vae : 2.2392 | kl : 0.0000>\n",
      "03-Jun-21 13:54:14 | == gmse <train: 0.0686 | val: 0.0616> \n",
      "03-Jun-21 13:54:54 | Epoch 0178 | lr: 0.00000103 | Time(s): 40.4707\n",
      "03-Jun-21 13:54:54 | == train Loss <unroll: 0.7710 | vae : 2.2389 | kl : 0.0000>\n",
      "03-Jun-21 13:54:54 | == gmse <train: 0.0685 | val: 0.0617> \n",
      "03-Jun-21 13:55:35 | Epoch 0179 | lr: 0.00000098 | Time(s): 40.4708\n",
      "03-Jun-21 13:55:35 | == train Loss <unroll: 0.7710 | vae : 2.2384 | kl : 0.0000>\n",
      "03-Jun-21 13:55:35 | == gmse <train: 0.0685 | val: 0.0617> \n",
      "03-Jun-21 13:56:15 | Epoch 0180 | lr: 0.00000093 | Time(s): 40.4697\n",
      "03-Jun-21 13:56:15 | == train Loss <unroll: 0.7710 | vae : 2.2380 | kl : 0.0000>\n",
      "03-Jun-21 13:56:15 | == gmse <train: 0.0685 | val: 0.0614> \n",
      "03-Jun-21 13:56:55 | Epoch 0181 | lr: 0.00000088 | Time(s): 40.4687\n",
      "03-Jun-21 13:56:55 | == train Loss <unroll: 0.7709 | vae : 2.2378 | kl : 0.0000>\n",
      "03-Jun-21 13:56:55 | == gmse <train: 0.0684 | val: 0.0616> \n",
      "03-Jun-21 13:57:36 | Epoch 0182 | lr: 0.00000084 | Time(s): 40.4669\n",
      "03-Jun-21 13:57:36 | == train Loss <unroll: 0.7709 | vae : 2.2374 | kl : 0.0000>\n",
      "03-Jun-21 13:57:36 | == gmse <train: 0.0684 | val: 0.0615> \n",
      "03-Jun-21 13:58:16 | Epoch 0183 | lr: 0.00000080 | Time(s): 40.4650\n",
      "03-Jun-21 13:58:16 | == train Loss <unroll: 0.7709 | vae : 2.2372 | kl : 0.0000>\n",
      "03-Jun-21 13:58:16 | == gmse <train: 0.0684 | val: 0.0615> \n",
      "03-Jun-21 13:58:56 | Epoch 0184 | lr: 0.00000076 | Time(s): 40.4634\n",
      "03-Jun-21 13:58:56 | == train Loss <unroll: 0.7709 | vae : 2.2369 | kl : 0.0000>\n",
      "03-Jun-21 13:58:56 | == gmse <train: 0.0684 | val: 0.0611> \n",
      "03-Jun-21 13:59:36 | Epoch 0185 | lr: 0.00000072 | Time(s): 40.4618\n",
      "03-Jun-21 13:59:36 | == train Loss <unroll: 0.7708 | vae : 2.2366 | kl : 0.0000>\n",
      "03-Jun-21 13:59:36 | == gmse <train: 0.0683 | val: 0.0614> \n",
      "03-Jun-21 14:00:16 | Epoch 0186 | lr: 0.00000068 | Time(s): 40.4603\n",
      "03-Jun-21 14:00:16 | == train Loss <unroll: 0.7708 | vae : 2.2365 | kl : 0.0000>\n",
      "03-Jun-21 14:00:16 | == gmse <train: 0.0683 | val: 0.0613> \n",
      "03-Jun-21 14:00:56 | Epoch 0187 | lr: 0.00000065 | Time(s): 40.4575\n",
      "03-Jun-21 14:00:56 | == train Loss <unroll: 0.7708 | vae : 2.2363 | kl : 0.0000>\n",
      "03-Jun-21 14:00:56 | == gmse <train: 0.0683 | val: 0.0613> \n",
      "03-Jun-21 14:01:36 | Epoch 0188 | lr: 0.00000062 | Time(s): 40.4554\n",
      "03-Jun-21 14:01:36 | == train Loss <unroll: 0.7708 | vae : 2.2359 | kl : 0.0000>\n",
      "03-Jun-21 14:01:36 | == gmse <train: 0.0683 | val: 0.0613> \n",
      "03-Jun-21 14:02:17 | Epoch 0189 | lr: 0.00000059 | Time(s): 40.4558\n",
      "03-Jun-21 14:02:17 | == train Loss <unroll: 0.7708 | vae : 2.2358 | kl : 0.0000>\n",
      "03-Jun-21 14:02:17 | == gmse <train: 0.0683 | val: 0.0613> \n",
      "03-Jun-21 14:02:57 | Epoch 0190 | lr: 0.00000056 | Time(s): 40.4563\n",
      "03-Jun-21 14:02:57 | == train Loss <unroll: 0.7707 | vae : 2.2356 | kl : 0.0000>\n",
      "03-Jun-21 14:02:57 | == gmse <train: 0.0683 | val: 0.0613> \n",
      "03-Jun-21 14:03:38 | Epoch 0191 | lr: 0.00000053 | Time(s): 40.4552\n",
      "03-Jun-21 14:03:38 | == train Loss <unroll: 0.7707 | vae : 2.2353 | kl : 0.0000>\n",
      "03-Jun-21 14:03:38 | == gmse <train: 0.0682 | val: 0.0612> \n",
      "03-Jun-21 14:04:18 | Epoch 0192 | lr: 0.00000050 | Time(s): 40.4541\n",
      "03-Jun-21 14:04:18 | == train Loss <unroll: 0.7707 | vae : 2.2352 | kl : 0.0000>\n",
      "03-Jun-21 14:04:18 | == gmse <train: 0.0682 | val: 0.0613> \n",
      "03-Jun-21 14:04:58 | Epoch 0193 | lr: 0.00000048 | Time(s): 40.4538\n",
      "03-Jun-21 14:04:58 | == train Loss <unroll: 0.7707 | vae : 2.2349 | kl : 0.0000>\n",
      "03-Jun-21 14:04:58 | == gmse <train: 0.0682 | val: 0.0611> \n",
      "03-Jun-21 14:05:39 | Epoch 0194 | lr: 0.00000045 | Time(s): 40.4536\n",
      "03-Jun-21 14:05:39 | == train Loss <unroll: 0.7707 | vae : 2.2348 | kl : 0.0000>\n",
      "03-Jun-21 14:05:39 | == gmse <train: 0.0682 | val: 0.0614> \n",
      "03-Jun-21 14:06:19 | Epoch 0195 | lr: 0.00000043 | Time(s): 40.4527\n",
      "03-Jun-21 14:06:19 | == train Loss <unroll: 0.7707 | vae : 2.2348 | kl : 0.0000>\n",
      "03-Jun-21 14:06:19 | == gmse <train: 0.0682 | val: 0.0612> \n",
      "03-Jun-21 14:07:00 | Epoch 0196 | lr: 0.00000041 | Time(s): 40.4550\n",
      "03-Jun-21 14:07:00 | == train Loss <unroll: 0.7707 | vae : 2.2345 | kl : 0.0000>\n",
      "03-Jun-21 14:07:00 | == gmse <train: 0.0682 | val: 0.0612> \n",
      "03-Jun-21 14:07:40 | Epoch 0197 | lr: 0.00000039 | Time(s): 40.4521\n",
      "03-Jun-21 14:07:40 | == train Loss <unroll: 0.7706 | vae : 2.2344 | kl : 0.0000>\n",
      "03-Jun-21 14:07:40 | == gmse <train: 0.0682 | val: 0.0613> \n",
      "03-Jun-21 14:08:20 | Epoch 0198 | lr: 0.00000037 | Time(s): 40.4527\n",
      "03-Jun-21 14:08:20 | == train Loss <unroll: 0.7706 | vae : 2.2342 | kl : 0.0000>\n",
      "03-Jun-21 14:08:20 | == gmse <train: 0.0682 | val: 0.0612> \n",
      "03-Jun-21 14:09:01 | Epoch 0199 | lr: 0.00000035 | Time(s): 40.4545\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "03-Jun-21 14:09:01 | == train Loss <unroll: 0.7706 | vae : 2.2341 | kl : 0.0000>\n",
      "03-Jun-21 14:09:01 | == gmse <train: 0.0682 | val: 0.0612> \n",
      "03-Jun-21 14:09:42 | Epoch 0200 | lr: 0.00000033 | Time(s): 40.4540\n",
      "03-Jun-21 14:09:42 | == train Loss <unroll: 0.7706 | vae : 2.2339 | kl : 0.0000>\n",
      "03-Jun-21 14:09:42 | == gmse <train: 0.0681 | val: 0.0611> \n",
      "03-Jun-21 14:10:22 | Epoch 0201 | lr: 0.00000032 | Time(s): 40.4559\n",
      "03-Jun-21 14:10:22 | == train Loss <unroll: 0.7706 | vae : 2.2338 | kl : 0.0000>\n",
      "03-Jun-21 14:10:22 | == gmse <train: 0.0681 | val: 0.0612> \n",
      "03-Jun-21 14:11:03 | Epoch 0202 | lr: 0.00000030 | Time(s): 40.4561\n",
      "03-Jun-21 14:11:03 | == train Loss <unroll: 0.7706 | vae : 2.2337 | kl : 0.0000>\n",
      "03-Jun-21 14:11:03 | == gmse <train: 0.0681 | val: 0.0612> \n",
      "03-Jun-21 14:11:43 | Epoch 0203 | lr: 0.00000029 | Time(s): 40.4563\n",
      "03-Jun-21 14:11:43 | == train Loss <unroll: 0.7706 | vae : 2.2336 | kl : 0.0000>\n",
      "03-Jun-21 14:11:43 | == gmse <train: 0.0681 | val: 0.0611> \n",
      "03-Jun-21 14:12:24 | Epoch 0204 | lr: 0.00000027 | Time(s): 40.4560\n",
      "03-Jun-21 14:12:24 | == train Loss <unroll: 0.7706 | vae : 2.2335 | kl : 0.0000>\n",
      "03-Jun-21 14:12:24 | == gmse <train: 0.0681 | val: 0.0612> \n",
      "03-Jun-21 14:13:04 | Epoch 0205 | lr: 0.00000026 | Time(s): 40.4543\n",
      "03-Jun-21 14:13:04 | == train Loss <unroll: 0.7706 | vae : 2.2335 | kl : 0.0000>\n",
      "03-Jun-21 14:13:04 | == gmse <train: 0.0681 | val: 0.0611> \n",
      "03-Jun-21 14:13:44 | Epoch 0206 | lr: 0.00000024 | Time(s): 40.4509\n",
      "03-Jun-21 14:13:44 | == train Loss <unroll: 0.7706 | vae : 2.2333 | kl : 0.0000>\n",
      "03-Jun-21 14:13:44 | == gmse <train: 0.0681 | val: 0.0611> \n",
      "03-Jun-21 14:14:24 | Epoch 0207 | lr: 0.00000023 | Time(s): 40.4517\n",
      "03-Jun-21 14:14:24 | == train Loss <unroll: 0.7706 | vae : 2.2332 | kl : 0.0000>\n",
      "03-Jun-21 14:14:24 | == gmse <train: 0.0681 | val: 0.0612> \n",
      "03-Jun-21 14:15:05 | Epoch 0208 | lr: 0.00000022 | Time(s): 40.4515\n",
      "03-Jun-21 14:15:05 | == train Loss <unroll: 0.7706 | vae : 2.2331 | kl : 0.0000>\n",
      "03-Jun-21 14:15:05 | == gmse <train: 0.0681 | val: 0.0611> \n",
      "03-Jun-21 14:15:45 | Epoch 0209 | lr: 0.00000021 | Time(s): 40.4516\n",
      "03-Jun-21 14:15:45 | == train Loss <unroll: 0.7705 | vae : 2.2330 | kl : 0.0000>\n",
      "03-Jun-21 14:15:45 | == gmse <train: 0.0681 | val: 0.0612> \n",
      "03-Jun-21 14:16:25 | Epoch 0210 | lr: 0.00000020 | Time(s): 40.4496\n",
      "03-Jun-21 14:16:25 | == train Loss <unroll: 0.7705 | vae : 2.2329 | kl : 0.0000>\n",
      "03-Jun-21 14:16:25 | == gmse <train: 0.0681 | val: 0.0612> \n",
      "03-Jun-21 14:17:05 | Epoch 0211 | lr: 0.00000019 | Time(s): 40.4483\n",
      "03-Jun-21 14:17:05 | == train Loss <unroll: 0.7705 | vae : 2.2329 | kl : 0.0000>\n",
      "03-Jun-21 14:17:05 | == gmse <train: 0.0681 | val: 0.0611> \n",
      "03-Jun-21 14:17:46 | Epoch 0212 | lr: 0.00000018 | Time(s): 40.4474\n",
      "03-Jun-21 14:17:46 | == train Loss <unroll: 0.7705 | vae : 2.2328 | kl : 0.0000>\n",
      "03-Jun-21 14:17:46 | == gmse <train: 0.0681 | val: 0.0610> \n",
      "03-Jun-21 14:18:26 | Epoch 0213 | lr: 0.00000017 | Time(s): 40.4456\n",
      "03-Jun-21 14:18:26 | == train Loss <unroll: 0.7705 | vae : 2.2328 | kl : 0.0000>\n",
      "03-Jun-21 14:18:26 | == gmse <train: 0.0681 | val: 0.0612> \n",
      "03-Jun-21 14:19:06 | Epoch 0214 | lr: 0.00000016 | Time(s): 40.4455\n",
      "03-Jun-21 14:19:06 | == train Loss <unroll: 0.7705 | vae : 2.2327 | kl : 0.0000>\n",
      "03-Jun-21 14:19:06 | == gmse <train: 0.0681 | val: 0.0612> \n",
      "03-Jun-21 14:19:46 | Epoch 0215 | lr: 0.00000015 | Time(s): 40.4441\n",
      "03-Jun-21 14:19:46 | == train Loss <unroll: 0.7705 | vae : 2.2326 | kl : 0.0000>\n",
      "03-Jun-21 14:19:46 | == gmse <train: 0.0681 | val: 0.0611> \n",
      "03-Jun-21 14:20:27 | Epoch 0216 | lr: 0.00000015 | Time(s): 40.4433\n",
      "03-Jun-21 14:20:27 | == train Loss <unroll: 0.7705 | vae : 2.2326 | kl : 0.0000>\n",
      "03-Jun-21 14:20:27 | == gmse <train: 0.0681 | val: 0.0611> \n",
      "03-Jun-21 14:21:07 | Epoch 0217 | lr: 0.00000014 | Time(s): 40.4437\n",
      "03-Jun-21 14:21:07 | == train Loss <unroll: 0.7705 | vae : 2.2325 | kl : 0.0000>\n",
      "03-Jun-21 14:21:07 | == gmse <train: 0.0681 | val: 0.0611> \n",
      "03-Jun-21 14:21:47 | Epoch 0218 | lr: 0.00000013 | Time(s): 40.4418\n",
      "03-Jun-21 14:21:47 | == train Loss <unroll: 0.7705 | vae : 2.2324 | kl : 0.0000>\n",
      "03-Jun-21 14:21:47 | == gmse <train: 0.0681 | val: 0.0611> \n",
      "03-Jun-21 14:22:27 | Epoch 0219 | lr: 0.00000013 | Time(s): 40.4402\n",
      "03-Jun-21 14:22:27 | == train Loss <unroll: 0.7705 | vae : 2.2324 | kl : 0.0000>\n",
      "03-Jun-21 14:22:27 | == gmse <train: 0.0681 | val: 0.0611> \n",
      "03-Jun-21 14:23:08 | Epoch 0220 | lr: 0.00000012 | Time(s): 40.4402\n",
      "03-Jun-21 14:23:08 | == train Loss <unroll: 0.7705 | vae : 2.2323 | kl : 0.0000>\n",
      "03-Jun-21 14:23:08 | == gmse <train: 0.0681 | val: 0.0611> \n",
      "03-Jun-21 14:23:48 | Epoch 0221 | lr: 0.00000011 | Time(s): 40.4384\n",
      "03-Jun-21 14:23:48 | == train Loss <unroll: 0.7705 | vae : 2.2330 | kl : 0.0000>\n",
      "03-Jun-21 14:23:48 | == gmse <train: 0.0681 | val: 0.0611> \n",
      "03-Jun-21 14:24:28 | Epoch 0222 | lr: 0.00000011 | Time(s): 40.4370\n",
      "03-Jun-21 14:24:28 | == train Loss <unroll: 0.7705 | vae : 2.2322 | kl : 0.0000>\n",
      "03-Jun-21 14:24:28 | == gmse <train: 0.0680 | val: 0.0611> \n",
      "03-Jun-21 14:25:08 | Epoch 0223 | lr: 0.00000010 | Time(s): 40.4378\n",
      "03-Jun-21 14:25:08 | == train Loss <unroll: 0.7705 | vae : 2.2321 | kl : 0.0000>\n",
      "03-Jun-21 14:25:08 | == gmse <train: 0.0680 | val: 0.0611> \n",
      "03-Jun-21 14:25:49 | Epoch 0224 | lr: 0.00000010 | Time(s): 40.4377\n",
      "03-Jun-21 14:25:49 | == train Loss <unroll: 0.7705 | vae : 2.2322 | kl : 0.0000>\n",
      "03-Jun-21 14:25:49 | == gmse <train: 0.0680 | val: 0.0610> \n",
      "03-Jun-21 14:26:29 | Epoch 0225 | lr: 0.00000009 | Time(s): 40.4363\n",
      "03-Jun-21 14:26:29 | == train Loss <unroll: 0.7705 | vae : 2.2321 | kl : 0.0000>\n",
      "03-Jun-21 14:26:29 | == gmse <train: 0.0680 | val: 0.0611> \n",
      "03-Jun-21 14:27:10 | Epoch 0226 | lr: 0.00000009 | Time(s): 40.4371\n",
      "03-Jun-21 14:27:10 | == train Loss <unroll: 0.7705 | vae : 2.2321 | kl : 0.0000>\n",
      "03-Jun-21 14:27:10 | == gmse <train: 0.0680 | val: 0.0610> \n",
      "03-Jun-21 14:27:49 | Epoch 0227 | lr: 0.00000008 | Time(s): 40.4346\n",
      "03-Jun-21 14:27:49 | == train Loss <unroll: 0.7705 | vae : 2.2320 | kl : 0.0000>\n",
      "03-Jun-21 14:27:49 | == gmse <train: 0.0680 | val: 0.0611> \n",
      "03-Jun-21 14:28:30 | Epoch 0228 | lr: 0.00000008 | Time(s): 40.4343\n",
      "03-Jun-21 14:28:30 | == train Loss <unroll: 0.7705 | vae : 2.2320 | kl : 0.0000>\n",
      "03-Jun-21 14:28:30 | == gmse <train: 0.0680 | val: 0.0611> \n",
      "03-Jun-21 14:29:10 | Epoch 0229 | lr: 0.00000008 | Time(s): 40.4347\n",
      "03-Jun-21 14:29:10 | == train Loss <unroll: 0.7705 | vae : 2.2320 | kl : 0.0000>\n",
      "03-Jun-21 14:29:10 | == gmse <train: 0.0680 | val: 0.0611> \n",
      "03-Jun-21 14:29:51 | Epoch 0230 | lr: 0.00000007 | Time(s): 40.4348\n",
      "03-Jun-21 14:29:51 | == train Loss <unroll: 0.7705 | vae : 2.2319 | kl : 0.0000>\n",
      "03-Jun-21 14:29:51 | == gmse <train: 0.0680 | val: 0.0611> \n",
      "03-Jun-21 14:30:31 | Epoch 0231 | lr: 0.00000007 | Time(s): 40.4335\n",
      "03-Jun-21 14:30:31 | == train Loss <unroll: 0.7705 | vae : 2.2319 | kl : 0.0000>\n",
      "03-Jun-21 14:30:31 | == gmse <train: 0.0680 | val: 0.0610> \n",
      "03-Jun-21 14:31:11 | Epoch 0232 | lr: 0.00000006 | Time(s): 40.4336\n",
      "03-Jun-21 14:31:11 | == train Loss <unroll: 0.7705 | vae : 2.2319 | kl : 0.0000>\n",
      "03-Jun-21 14:31:11 | == gmse <train: 0.0680 | val: 0.0610> \n",
      "03-Jun-21 14:31:52 | Epoch 0233 | lr: 0.00000006 | Time(s): 40.4326\n",
      "03-Jun-21 14:31:52 | == train Loss <unroll: 0.7705 | vae : 2.2318 | kl : 0.0000>\n",
      "03-Jun-21 14:31:52 | == gmse <train: 0.0680 | val: 0.0610> \n",
      "03-Jun-21 14:32:32 | Epoch 0234 | lr: 0.00000006 | Time(s): 40.4306\n",
      "03-Jun-21 14:32:32 | == train Loss <unroll: 0.7705 | vae : 2.2319 | kl : 0.0000>\n",
      "03-Jun-21 14:32:32 | == gmse <train: 0.0680 | val: 0.0610> \n",
      "03-Jun-21 14:33:12 | Epoch 0235 | lr: 0.00000006 | Time(s): 40.4305\n",
      "03-Jun-21 14:33:12 | == train Loss <unroll: 0.7705 | vae : 2.2318 | kl : 0.0000>\n",
      "03-Jun-21 14:33:12 | == gmse <train: 0.0680 | val: 0.0611> \n",
      "03-Jun-21 14:33:52 | Epoch 0236 | lr: 0.00000005 | Time(s): 40.4298\n",
      "03-Jun-21 14:33:52 | == train Loss <unroll: 0.7705 | vae : 2.2317 | kl : 0.0000>\n",
      "03-Jun-21 14:33:52 | == gmse <train: 0.0680 | val: 0.0610> \n",
      "03-Jun-21 14:34:33 | Epoch 0237 | lr: 0.00000005 | Time(s): 40.4293\n",
      "03-Jun-21 14:34:33 | == train Loss <unroll: 0.7705 | vae : 2.2318 | kl : 0.0000>\n",
      "03-Jun-21 14:34:33 | == gmse <train: 0.0680 | val: 0.0610> \n",
      "03-Jun-21 14:35:13 | Epoch 0238 | lr: 0.00000005 | Time(s): 40.4284\n",
      "03-Jun-21 14:35:13 | == train Loss <unroll: 0.7705 | vae : 2.2318 | kl : 0.0000>\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "03-Jun-21 14:35:13 | == gmse <train: 0.0680 | val: 0.0610> \n",
      "03-Jun-21 14:35:53 | Epoch 0239 | lr: 0.00000005 | Time(s): 40.4280\n",
      "03-Jun-21 14:35:53 | == train Loss <unroll: 0.7705 | vae : 2.2316 | kl : 0.0000>\n",
      "03-Jun-21 14:35:53 | == gmse <train: 0.0680 | val: 0.0610> \n",
      "03-Jun-21 14:36:34 | Epoch 0240 | lr: 0.00000004 | Time(s): 40.4290\n",
      "03-Jun-21 14:36:34 | == train Loss <unroll: 0.7705 | vae : 2.2316 | kl : 0.0000>\n",
      "03-Jun-21 14:36:34 | == gmse <train: 0.0680 | val: 0.0610> \n",
      "03-Jun-21 14:37:14 | Epoch 0241 | lr: 0.00000004 | Time(s): 40.4287\n",
      "03-Jun-21 14:37:14 | == train Loss <unroll: 0.7705 | vae : 2.2316 | kl : 0.0000>\n",
      "03-Jun-21 14:37:14 | == gmse <train: 0.0680 | val: 0.0611> \n",
      "03-Jun-21 14:37:55 | Epoch 0242 | lr: 0.00000004 | Time(s): 40.4288\n",
      "03-Jun-21 14:37:55 | == train Loss <unroll: 0.7705 | vae : 2.2316 | kl : 0.0000>\n",
      "03-Jun-21 14:37:55 | == gmse <train: 0.0680 | val: 0.0611> \n",
      "03-Jun-21 14:38:35 | Epoch 0243 | lr: 0.00000004 | Time(s): 40.4273\n",
      "03-Jun-21 14:38:35 | == train Loss <unroll: 0.7705 | vae : 2.2316 | kl : 0.0000>\n",
      "03-Jun-21 14:38:35 | == gmse <train: 0.0680 | val: 0.0611> \n",
      "03-Jun-21 14:39:15 | Epoch 0244 | lr: 0.00000003 | Time(s): 40.4268\n",
      "03-Jun-21 14:39:15 | == train Loss <unroll: 0.7705 | vae : 2.2315 | kl : 0.0000>\n",
      "03-Jun-21 14:39:15 | == gmse <train: 0.0680 | val: 0.0611> \n",
      "03-Jun-21 14:39:55 | Epoch 0245 | lr: 0.00000003 | Time(s): 40.4263\n",
      "03-Jun-21 14:39:55 | == train Loss <unroll: 0.7705 | vae : 2.2316 | kl : 0.0000>\n",
      "03-Jun-21 14:39:55 | == gmse <train: 0.0680 | val: 0.0610> \n",
      "03-Jun-21 14:40:36 | Epoch 0246 | lr: 0.00000003 | Time(s): 40.4257\n",
      "03-Jun-21 14:40:36 | == train Loss <unroll: 0.7705 | vae : 2.2317 | kl : 0.0000>\n",
      "03-Jun-21 14:40:36 | == gmse <train: 0.0680 | val: 0.0611> \n",
      "03-Jun-21 14:41:16 | Epoch 0247 | lr: 0.00000003 | Time(s): 40.4254\n",
      "03-Jun-21 14:41:16 | == train Loss <unroll: 0.7705 | vae : 2.2315 | kl : 0.0000>\n",
      "03-Jun-21 14:41:16 | == gmse <train: 0.0680 | val: 0.0611> \n",
      "03-Jun-21 14:41:56 | Epoch 0248 | lr: 0.00000003 | Time(s): 40.4242\n",
      "03-Jun-21 14:41:56 | == train Loss <unroll: 0.7705 | vae : 2.2315 | kl : 0.0000>\n",
      "03-Jun-21 14:41:56 | == gmse <train: 0.0680 | val: 0.0610> \n",
      "03-Jun-21 14:42:36 | Epoch 0249 | lr: 0.00000003 | Time(s): 40.4219\n",
      "03-Jun-21 14:42:36 | == train Loss <unroll: 0.7705 | vae : 2.2315 | kl : 0.0000>\n",
      "03-Jun-21 14:42:36 | == gmse <train: 0.0680 | val: 0.0610> \n",
      "03-Jun-21 14:43:16 | Epoch 0250 | lr: 0.00000003 | Time(s): 40.4204\n",
      "03-Jun-21 14:43:16 | == train Loss <unroll: 0.7705 | vae : 2.2315 | kl : 0.0000>\n",
      "03-Jun-21 14:43:16 | == gmse <train: 0.0680 | val: 0.0610> \n",
      "03-Jun-21 14:43:56 | Epoch 0251 | lr: 0.00000002 | Time(s): 40.4202\n",
      "03-Jun-21 14:43:56 | == train Loss <unroll: 0.7705 | vae : 2.2316 | kl : 0.0000>\n",
      "03-Jun-21 14:43:56 | == gmse <train: 0.0680 | val: 0.0610> \n",
      "03-Jun-21 14:44:36 | Epoch 0252 | lr: 0.00000002 | Time(s): 40.4190\n",
      "03-Jun-21 14:44:36 | == train Loss <unroll: 0.7705 | vae : 2.2314 | kl : 0.0000>\n",
      "03-Jun-21 14:44:36 | == gmse <train: 0.0680 | val: 0.0610> \n",
      "03-Jun-21 14:45:17 | Epoch 0253 | lr: 0.00000002 | Time(s): 40.4196\n",
      "03-Jun-21 14:45:17 | == train Loss <unroll: 0.7705 | vae : 2.2314 | kl : 0.0000>\n",
      "03-Jun-21 14:45:17 | == gmse <train: 0.0680 | val: 0.0610> \n",
      "03-Jun-21 14:45:57 | Epoch 0254 | lr: 0.00000002 | Time(s): 40.4180\n",
      "03-Jun-21 14:45:57 | == train Loss <unroll: 0.7705 | vae : 2.2314 | kl : 0.0000>\n",
      "03-Jun-21 14:45:57 | == gmse <train: 0.0680 | val: 0.0610> \n",
      "03-Jun-21 14:46:37 | Epoch 0255 | lr: 0.00000002 | Time(s): 40.4164\n",
      "03-Jun-21 14:46:37 | == train Loss <unroll: 0.7705 | vae : 2.2314 | kl : 0.0000>\n",
      "03-Jun-21 14:46:37 | == gmse <train: 0.0680 | val: 0.0610> \n",
      "03-Jun-21 14:47:17 | Epoch 0256 | lr: 0.00000002 | Time(s): 40.4127\n",
      "03-Jun-21 14:47:17 | == train Loss <unroll: 0.7705 | vae : 2.2314 | kl : 0.0000>\n",
      "03-Jun-21 14:47:17 | == gmse <train: 0.0680 | val: 0.0610> \n",
      "03-Jun-21 14:47:53 | Epoch 0257 | lr: 0.00000002 | Time(s): 40.3970\n",
      "03-Jun-21 14:47:53 | == train Loss <unroll: 0.7705 | vae : 2.2314 | kl : 0.0000>\n",
      "03-Jun-21 14:47:53 | == gmse <train: 0.0680 | val: 0.0610> \n",
      "03-Jun-21 14:48:29 | Epoch 0258 | lr: 0.00000002 | Time(s): 40.3807\n",
      "03-Jun-21 14:48:29 | == train Loss <unroll: 0.7705 | vae : 2.2313 | kl : 0.0000>\n",
      "03-Jun-21 14:48:29 | == gmse <train: 0.0680 | val: 0.0610> \n",
      "03-Jun-21 14:49:05 | Epoch 0259 | lr: 0.00000002 | Time(s): 40.3643\n",
      "03-Jun-21 14:49:05 | == train Loss <unroll: 0.7705 | vae : 2.2314 | kl : 0.0000>\n",
      "03-Jun-21 14:49:05 | == gmse <train: 0.0680 | val: 0.0610> \n",
      "03-Jun-21 14:49:42 | Epoch 0260 | lr: 0.00000002 | Time(s): 40.3488\n",
      "03-Jun-21 14:49:42 | == train Loss <unroll: 0.7705 | vae : 2.2315 | kl : 0.0000>\n",
      "03-Jun-21 14:49:42 | == gmse <train: 0.0680 | val: 0.0610> \n",
      "03-Jun-21 14:50:18 | Epoch 0261 | lr: 0.00000001 | Time(s): 40.3336\n",
      "03-Jun-21 14:50:18 | == train Loss <unroll: 0.7705 | vae : 2.2314 | kl : 0.0000>\n",
      "03-Jun-21 14:50:18 | == gmse <train: 0.0680 | val: 0.0610> \n",
      "03-Jun-21 14:50:54 | Epoch 0262 | lr: 0.00000001 | Time(s): 40.3172\n",
      "03-Jun-21 14:50:54 | == train Loss <unroll: 0.7705 | vae : 2.2313 | kl : 0.0000>\n",
      "03-Jun-21 14:50:54 | == gmse <train: 0.0680 | val: 0.0610> \n",
      "03-Jun-21 14:51:30 | Epoch 0263 | lr: 0.00000001 | Time(s): 40.3017\n",
      "03-Jun-21 14:51:30 | == train Loss <unroll: 0.7705 | vae : 2.2313 | kl : 0.0000>\n",
      "03-Jun-21 14:51:30 | == gmse <train: 0.0680 | val: 0.0610> \n",
      "03-Jun-21 14:52:06 | Epoch 0264 | lr: 0.00000001 | Time(s): 40.2861\n",
      "03-Jun-21 14:52:06 | == train Loss <unroll: 0.7705 | vae : 2.2313 | kl : 0.0000>\n",
      "03-Jun-21 14:52:06 | == gmse <train: 0.0680 | val: 0.0610> \n",
      "03-Jun-21 14:52:43 | Epoch 0265 | lr: 0.00000001 | Time(s): 40.2711\n",
      "03-Jun-21 14:52:43 | == train Loss <unroll: 0.7705 | vae : 2.2313 | kl : 0.0000>\n",
      "03-Jun-21 14:52:43 | == gmse <train: 0.0680 | val: 0.0610> \n",
      "03-Jun-21 14:53:19 | Epoch 0266 | lr: 0.00000001 | Time(s): 40.2559\n",
      "03-Jun-21 14:53:19 | == train Loss <unroll: 0.7705 | vae : 2.2314 | kl : 0.0000>\n",
      "03-Jun-21 14:53:19 | == gmse <train: 0.0680 | val: 0.0610> \n",
      "03-Jun-21 14:53:55 | Epoch 0267 | lr: 0.00000001 | Time(s): 40.2409\n",
      "03-Jun-21 14:53:55 | == train Loss <unroll: 0.7705 | vae : 2.2314 | kl : 0.0000>\n",
      "03-Jun-21 14:53:55 | == gmse <train: 0.0680 | val: 0.0610> \n",
      "03-Jun-21 14:54:32 | Epoch 0268 | lr: 0.00000001 | Time(s): 40.2260\n",
      "03-Jun-21 14:54:32 | == train Loss <unroll: 0.7705 | vae : 2.2313 | kl : 0.0000>\n",
      "03-Jun-21 14:54:32 | == gmse <train: 0.0680 | val: 0.0610> \n",
      "03-Jun-21 14:55:08 | Epoch 0269 | lr: 0.00000001 | Time(s): 40.2103\n",
      "03-Jun-21 14:55:08 | == train Loss <unroll: 0.7705 | vae : 2.2313 | kl : 0.0000>\n",
      "03-Jun-21 14:55:08 | == gmse <train: 0.0680 | val: 0.0610> \n",
      "03-Jun-21 14:55:43 | Epoch 0270 | lr: 0.00000001 | Time(s): 40.1939\n",
      "03-Jun-21 14:55:43 | == train Loss <unroll: 0.7705 | vae : 2.2313 | kl : 0.0000>\n",
      "03-Jun-21 14:55:43 | == gmse <train: 0.0680 | val: 0.0610> \n",
      "03-Jun-21 14:56:20 | Epoch 0271 | lr: 0.00000001 | Time(s): 40.1797\n",
      "03-Jun-21 14:56:20 | == train Loss <unroll: 0.7705 | vae : 2.2313 | kl : 0.0000>\n",
      "03-Jun-21 14:56:20 | == gmse <train: 0.0680 | val: 0.0610> \n",
      "03-Jun-21 14:56:56 | Epoch 0272 | lr: 0.00000001 | Time(s): 40.1656\n",
      "03-Jun-21 14:56:56 | == train Loss <unroll: 0.7705 | vae : 2.2312 | kl : 0.0000>\n",
      "03-Jun-21 14:56:56 | == gmse <train: 0.0680 | val: 0.0610> \n",
      "03-Jun-21 14:57:32 | Epoch 0273 | lr: 0.00000001 | Time(s): 40.1510\n",
      "03-Jun-21 14:57:32 | == train Loss <unroll: 0.7705 | vae : 2.2313 | kl : 0.0000>\n",
      "03-Jun-21 14:57:32 | == gmse <train: 0.0680 | val: 0.0610> \n",
      "03-Jun-21 14:58:08 | Epoch 0274 | lr: 0.00000001 | Time(s): 40.1359\n",
      "03-Jun-21 14:58:08 | == train Loss <unroll: 0.7705 | vae : 2.2313 | kl : 0.0000>\n",
      "03-Jun-21 14:58:08 | == gmse <train: 0.0680 | val: 0.0610> \n",
      "03-Jun-21 14:58:44 | Epoch 0275 | lr: 0.00000001 | Time(s): 40.1219\n",
      "03-Jun-21 14:58:44 | == train Loss <unroll: 0.7705 | vae : 2.2312 | kl : 0.0000>\n",
      "03-Jun-21 14:58:45 | == gmse <train: 0.0680 | val: 0.0610> \n",
      "03-Jun-21 14:59:21 | Epoch 0276 | lr: 0.00000001 | Time(s): 40.1075\n",
      "03-Jun-21 14:59:21 | == train Loss <unroll: 0.7705 | vae : 2.2313 | kl : 0.0000>\n",
      "03-Jun-21 14:59:21 | == gmse <train: 0.0680 | val: 0.0610> \n",
      "03-Jun-21 14:59:57 | Epoch 0277 | lr: 0.00000001 | Time(s): 40.0940\n",
      "03-Jun-21 14:59:57 | == train Loss <unroll: 0.7705 | vae : 2.2312 | kl : 0.0000>\n",
      "03-Jun-21 14:59:57 | == gmse <train: 0.0680 | val: 0.0610> \n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "03-Jun-21 15:00:33 | Epoch 0278 | lr: 0.00000001 | Time(s): 40.0801\n",
      "03-Jun-21 15:00:33 | == train Loss <unroll: 0.7705 | vae : 2.2313 | kl : 0.0000>\n",
      "03-Jun-21 15:00:33 | == gmse <train: 0.0680 | val: 0.0610> \n",
      "03-Jun-21 15:01:09 | Epoch 0279 | lr: 0.00000001 | Time(s): 40.0662\n",
      "03-Jun-21 15:01:09 | == train Loss <unroll: 0.7705 | vae : 2.2312 | kl : 0.0000>\n",
      "03-Jun-21 15:01:09 | == gmse <train: 0.0680 | val: 0.0610> \n",
      "03-Jun-21 15:01:46 | Epoch 0280 | lr: 0.00000001 | Time(s): 40.0523\n",
      "03-Jun-21 15:01:46 | == train Loss <unroll: 0.7705 | vae : 2.2312 | kl : 0.0000>\n",
      "03-Jun-21 15:01:46 | == gmse <train: 0.0680 | val: 0.0610> \n",
      "03-Jun-21 15:02:22 | Epoch 0281 | lr: 0.00000001 | Time(s): 40.0381\n",
      "03-Jun-21 15:02:22 | == train Loss <unroll: 0.7705 | vae : 2.2312 | kl : 0.0000>\n",
      "03-Jun-21 15:02:22 | == gmse <train: 0.0680 | val: 0.0610> \n",
      "03-Jun-21 15:02:58 | Epoch 0282 | lr: 0.00000000 | Time(s): 40.0255\n",
      "03-Jun-21 15:02:58 | == train Loss <unroll: 0.7705 | vae : 2.2312 | kl : 0.0000>\n",
      "03-Jun-21 15:02:58 | == gmse <train: 0.0680 | val: 0.0610> \n",
      "03-Jun-21 15:03:34 | Epoch 0283 | lr: 0.00000000 | Time(s): 40.0119\n",
      "03-Jun-21 15:03:34 | == train Loss <unroll: 0.7705 | vae : 2.2312 | kl : 0.0000>\n",
      "03-Jun-21 15:03:34 | == gmse <train: 0.0680 | val: 0.0610> \n",
      "03-Jun-21 15:04:11 | Epoch 0284 | lr: 0.00000000 | Time(s): 39.9987\n",
      "03-Jun-21 15:04:11 | == train Loss <unroll: 0.7705 | vae : 2.2313 | kl : 0.0000>\n",
      "03-Jun-21 15:04:11 | == gmse <train: 0.0680 | val: 0.0610> \n",
      "03-Jun-21 15:04:47 | Epoch 0285 | lr: 0.00000000 | Time(s): 39.9854\n",
      "03-Jun-21 15:04:47 | == train Loss <unroll: 0.7705 | vae : 2.2312 | kl : 0.0000>\n",
      "03-Jun-21 15:04:47 | == gmse <train: 0.0680 | val: 0.0610> \n",
      "03-Jun-21 15:05:23 | Epoch 0286 | lr: 0.00000000 | Time(s): 39.9725\n",
      "03-Jun-21 15:05:23 | == train Loss <unroll: 0.7705 | vae : 2.2312 | kl : 0.0000>\n",
      "03-Jun-21 15:05:23 | == gmse <train: 0.0680 | val: 0.0610> \n",
      "03-Jun-21 15:06:00 | Epoch 0287 | lr: 0.00000000 | Time(s): 39.9601\n",
      "03-Jun-21 15:06:00 | == train Loss <unroll: 0.7705 | vae : 2.2312 | kl : 0.0000>\n",
      "03-Jun-21 15:06:00 | == gmse <train: 0.0680 | val: 0.0610> \n",
      "03-Jun-21 15:06:36 | Epoch 0288 | lr: 0.00000000 | Time(s): 39.9470\n",
      "03-Jun-21 15:06:36 | == train Loss <unroll: 0.7705 | vae : 2.2312 | kl : 0.0000>\n",
      "03-Jun-21 15:06:36 | == gmse <train: 0.0680 | val: 0.0610> \n",
      "03-Jun-21 15:07:12 | Epoch 0289 | lr: 0.00000000 | Time(s): 39.9342\n",
      "03-Jun-21 15:07:12 | == train Loss <unroll: 0.7705 | vae : 2.2312 | kl : 0.0000>\n",
      "03-Jun-21 15:07:12 | == gmse <train: 0.0680 | val: 0.0610> \n",
      "03-Jun-21 15:07:48 | Epoch 0290 | lr: 0.00000000 | Time(s): 39.9213\n",
      "03-Jun-21 15:07:48 | == train Loss <unroll: 0.7705 | vae : 2.2312 | kl : 0.0000>\n",
      "03-Jun-21 15:07:48 | == gmse <train: 0.0680 | val: 0.0610> \n",
      "03-Jun-21 15:08:24 | Epoch 0291 | lr: 0.00000000 | Time(s): 39.9086\n",
      "03-Jun-21 15:08:24 | == train Loss <unroll: 0.7705 | vae : 2.2311 | kl : 0.0000>\n",
      "03-Jun-21 15:08:24 | == gmse <train: 0.0680 | val: 0.0610> \n",
      "03-Jun-21 15:09:01 | Epoch 0292 | lr: 0.00000000 | Time(s): 39.8968\n",
      "03-Jun-21 15:09:01 | == train Loss <unroll: 0.7705 | vae : 2.2312 | kl : 0.0000>\n",
      "03-Jun-21 15:09:01 | == gmse <train: 0.0680 | val: 0.0610> \n",
      "03-Jun-21 15:09:37 | Epoch 0293 | lr: 0.00000000 | Time(s): 39.8839\n",
      "03-Jun-21 15:09:37 | == train Loss <unroll: 0.7705 | vae : 2.2311 | kl : 0.0000>\n",
      "03-Jun-21 15:09:37 | == gmse <train: 0.0680 | val: 0.0610> \n",
      "03-Jun-21 15:10:13 | Epoch 0294 | lr: 0.00000000 | Time(s): 39.8715\n",
      "03-Jun-21 15:10:13 | == train Loss <unroll: 0.7705 | vae : 2.2312 | kl : 0.0000>\n",
      "03-Jun-21 15:10:13 | == gmse <train: 0.0680 | val: 0.0611> \n",
      "03-Jun-21 15:10:49 | Epoch 0295 | lr: 0.00000000 | Time(s): 39.8592\n",
      "03-Jun-21 15:10:49 | == train Loss <unroll: 0.7705 | vae : 2.2311 | kl : 0.0000>\n",
      "03-Jun-21 15:10:49 | == gmse <train: 0.0680 | val: 0.0610> \n",
      "03-Jun-21 15:11:26 | Epoch 0296 | lr: 0.00000000 | Time(s): 39.8471\n",
      "03-Jun-21 15:11:26 | == train Loss <unroll: 0.7705 | vae : 2.2312 | kl : 0.0000>\n",
      "03-Jun-21 15:11:26 | == gmse <train: 0.0680 | val: 0.0610> \n",
      "03-Jun-21 15:12:02 | Epoch 0297 | lr: 0.00000000 | Time(s): 39.8352\n",
      "03-Jun-21 15:12:02 | == train Loss <unroll: 0.7705 | vae : 2.2312 | kl : 0.0000>\n",
      "03-Jun-21 15:12:02 | == gmse <train: 0.0680 | val: 0.0610> \n",
      "03-Jun-21 15:12:38 | Epoch 0298 | lr: 0.00000000 | Time(s): 39.8230\n",
      "03-Jun-21 15:12:38 | == train Loss <unroll: 0.7705 | vae : 2.2312 | kl : 0.0000>\n",
      "03-Jun-21 15:12:38 | == gmse <train: 0.0680 | val: 0.0611> \n",
      "03-Jun-21 15:13:15 | Epoch 0299 | lr: 0.00000000 | Time(s): 39.8112\n",
      "03-Jun-21 15:13:15 | == train Loss <unroll: 0.7705 | vae : 2.2312 | kl : 0.0000>\n",
      "03-Jun-21 15:13:15 | == gmse <train: 0.0680 | val: 0.0610> \n",
      "03-Jun-21 15:13:51 | Epoch 0300 | lr: 0.00000000 | Time(s): 39.7994\n",
      "03-Jun-21 15:13:51 | == train Loss <unroll: 0.7705 | vae : 2.2311 | kl : 0.0000>\n",
      "03-Jun-21 15:13:51 | == gmse <train: 0.0680 | val: 0.0610> \n"
     ]
    }
   ],
   "source": [
    "# Training:\n",
    "n_epochs = 300\n",
    "\n",
    "dur = []\n",
    "\n",
    "epoch_train_gmse = []\n",
    "epoch_val_gmse = []\n",
    "\n",
    "for epoch in range(n_epochs):\n",
    "\n",
    "    train_unrolling_loss, train_vae_loss, train_kl_loss, train_gmse, val_gmse = [], [], [], [], []\n",
    "\n",
    "    t0 = time.time()\n",
    "\n",
    "    net.train()\n",
    "    for z, w_gt_batch in train_loader:\n",
    "        z = z.to(device)\n",
    "        w_gt_batch = w_gt_batch.to(device)\n",
    "        # the last batch can be smaller than batch_size, so loop over the actual batch size\n",
    "        this_batch_size = w_gt_batch.size()[0]\n",
    "\n",
    "        optimizer.zero_grad()\n",
    "        w_list, vae_loss, vae_kl, _ = net.forward(z, w_gt_batch, threshold=1e-04, kl_hyper=1)\n",
    "\n",
    "        # acc_loss per sample (w_list[i] holds every unrolled layer of sample i), averaged over the batch\n",
    "        unrolling_loss = torch.mean(\n",
    "            torch.stack([acc_loss(w_list[i, :, :], w_gt_batch[i, :], dn=0.9) for i in range(this_batch_size)])\n",
    "        )\n",
    "\n",
    "        loss = unrolling_loss + vae_loss\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        # monitoring only: GMSE of the last unrolled layer, no autograd graph needed\n",
    "        with torch.no_grad():\n",
    "            w_pred = w_list[:, num_unroll - 1, :]\n",
    "            gmse = gmse_loss_batch_mean(w_pred, w_gt_batch)\n",
    "\n",
    "        train_gmse.append(gmse.item())\n",
    "        train_unrolling_loss.append(unrolling_loss.item())\n",
    "        train_vae_loss.append(vae_loss.item())\n",
    "        train_kl_loss.append(vae_kl.item())\n",
    "\n",
    "    scheduler.step()\n",
    "\n",
    "    net.eval()\n",
    "    with torch.no_grad():  # validation needs no gradients; saves memory and time\n",
    "        for z, w_gt_batch in val_loader:\n",
    "            z = z.to(device)\n",
    "            w_gt_batch = w_gt_batch.to(device)\n",
    "\n",
    "            w_list = net.validation(z, threshold=1e-04)\n",
    "            w_pred = torch.clamp(w_list[:, num_unroll - 1, :], min=0)\n",
    "            loss = gmse_loss_batch_mean(w_pred, w_gt_batch)\n",
    "            val_gmse.append(loss.item())\n",
    "\n",
    "    dur.append(time.time() - t0)\n",
    "\n",
    "    # get_last_lr() reports the lr set by the latest scheduler.step(); calling get_lr()\n",
    "    # outside step() is deprecated, warns, and may not match the lr actually in use\n",
    "    logging.info(\"Epoch {:04d} | lr: {:04.8f} | Time(s): {:.4f}\".format(epoch + 1, scheduler.get_last_lr()[0], np.mean(dur)))\n",
    "    logging.info(\"== train Loss <unroll: {:04.4f} | vae : {:04.4f} | kl : {:04.4f}>\".format(np.mean(train_unrolling_loss),\n",
    "                                                                                  np.mean(train_vae_loss),\n",
    "                                                                                  np.mean(train_kl_loss)))\n",
    "    logging.info(\"== gmse <train: {:04.4f} | val: {:04.4f}> \".format(np.mean(train_gmse), np.mean(val_gmse)))\n",
    "\n",
    "    epoch_train_gmse.append(np.mean(train_gmse))\n",
    "    epoch_val_gmse.append(np.mean(val_gmse))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXgAAAEICAYAAABVv+9nAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAAAgYElEQVR4nO3deZRcZ3nn8e9TVb13S91St6zdWmws29gWdmODzWYw2JiwzAkBQ+AAhxmTGZhhGwiEGULCDEwyhAkzh2RwWMYJDkuMHRzjBLxgHCdg05KFLUsyRrZlbZa6tfWmqq7lmT/urVa71dWb6lZ13fp9zunTVbeq732vr/Xrp9967/uauyMiIvGTqHYDREQkGgp4EZGYUsCLiMSUAl5EJKYU8CIiMaWAFxGJqcgC3szOM7NtE74GzewjUR1PRESezyoxDt7MksB+4Ap331Pqfd3d3b5u3brI2yMiEhdbtmwZcPeeqV5LVagNrwF2TxfuAOvWraOvr69CTRIRqX1mVjJXK9UHfwPwnaleMLMbzazPzPr6+/sr1BwRkfiLPODNrBF4E/B3U73u7je5e6+79/b0TPlXhoiIzEMlKvjXA1vd/VAFjiUiIqFKBPw7KNE9IyIi0Yk04M2sFXgtcFuUxxERkdNFOorG3UeBpVEeQ0REpqY7WUVEYqquAr7vmaM88uyxajdDRKQi6irg//tdO/n3395KNl+odlNERCJXVwF/cizPc4Np7nrsYLWbIiISuboK+HQ2D8A3H3warUUrInFXVwGfyRVoaUjyq30n2Kq+eBGJuboK+HQ2z/UXrWBRc4q//vm0856JiNS8ugr4TK7AkrYGLl7dybNHR6vdHBGRSNVNwLs76WyeplSSxlSCsZxG0ohIvNVNwOcKTsGhuSFBY1IBLyLxVzcBnwkDvVjBayy8iMRd3QR8cYhkU0NCXTQiUhfqJuCLFXxzsQ9eFbyIxFzdBPzzKvhkYjzwRUTiqm4CPpM91QffpC4aEakD9RPwuUl98PmCpisQkVirm4BPZyf0wScTuAdDJ0VE4qpuAn5yBQ+om0ZEYq1uAj493gevgBeR+lA3AV+s4JsbkqcCXkMlRSTG6ifgJ1bwSVXwIhJ/kQa8mXWa2a1mtsvMdprZS6M83nSmquA1Fl5E4iwV8f6/AvyTu7/VzBqB1oiPV9KpuWgSNKkPXkTqQGQBb2aLgFcA7wVw9zFgLKrjzaR4J6v64EWkXkTZRbMB6Ae+ZWaPmNnXzaxt8pvM7EYz6zOzvv7+/sgak8kVSBikEkZjMgmogheReIsy4FPApcBfuvuLgBHgU5Pf5O43uXuvu/f29PRE1pjiYh9mpmGSIlIXogz4fcA+d38ofH4rQeBXRSZXoLkhON1TXTT5ajVHRCRykQW8uz8H7DWz88JNrwF2RHW8mWSyBZpSQdeMhkmKSD2IehTNfwRuCUfQPAW8L+LjlZTO5U+r4DVMUkTiLNKAd/dtQG+Ux5itiRW8hkmKSD2omztZp6rgNUxSROKsbgJeffAiUm/qJuDTuTxNkyr4P/qHHXz4u49Us1kiIpGpm4A/OZanpSGs4FOnTvuH2w5Uq0kiIpGqm4AfSufoaG4AgrtZzarcIBGRiNVRwGfpaA4GDZkZDcm6OXURqVN1kXLuznAmNx7wAAWtxyoiMVcXAT8ylqfgPC/gteC2iMRdXQT8UDoLMN4HLyJSD+oi4IfTOQDam6a+cTenG55EJIbqIuAHw4Cf2EUzUVo3PIlIDNVFwM/URZPJatpgEYmfOgl4VfAiUn8U8KiCF5F4qouAH86c3kVz+fol44/TWVXwIhI/dRHwQ+kcZtDWmBzf9v0PvJRvve/FAGRyquBFJH7qJuDbm1LYpAloigt/qIIXkTiqi4AfTGdZNMUImuZwdklV8CISR3UR8MFMkqd/wKoKXkTirC4CfrhEwKuCF5E4i3TRbTN7BhgC8kDO3auyAPdQJktPe9Np24sV
fEYVvIjEUKQBH7ra3QcqcJySTo7laW0sXcGnVcGLSAzVRRdNJlcYr9YnGu+iUQUvIjEUdcA78BMz22JmN0Z8rJLS2QJNDcnTtp/6kFUVvIjET9RdNFe5+wEzWwbcbWa73P2BiW8Ig/9GgLVr10bSiEwuP2UF35BMkEwYGc1FIyIxFGkF7+4Hwu+HgduBy6d4z03u3uvuvT09PZG0I5MtjHfHTNaUSqiCF5FYiizgzazNzDqKj4HXAdujOl4phYIzlp+6Dx6CfnhV8CISR1F20ZwF3B5OD5AC/tbd/ynC402pGN6q4EWk3kQW8O7+FHBJVPufreJNTNNV8JoPXkTiKPbDJIsVfFPD1KfalEpoPngRiaXYB3yx+6U5VaKLRhW8iMRU7AN+pgq+WRW8iMRU7ANeFbyI1KvYB/xMFXxHc4oTo2OVbJKISEXEPuDHK/gSwyTXL21j77GTZPOq4kUkXmIf8MWJxEoNk1zf3Ua+4Ow9OlrJZomIRC7+AT/DjU7re9oAeHpgpGJtEhGphNgHfLGLplQFv6FbAS8i8RT7gB//kLXEKJrO1kaWtDWyu18BLyLxEvuAP/Uha+lTXd/dxlP9w5VqkohIRcQ+4Geq4AFWLG6mfyhTqSaJiFRE7AN+pj54gNbGJKNjuptVROIl9gGfyRVoTCZIJKzke1obU4yO5SrYKhGR6NVBwOdL3sVapApeROIo9gGfzham7X+HIOBzBWdMc9KISIzEPuBLLbg9UUtjsO7JSVXxIhIj8Q/4bGHaIZIQVPAAo1n1w4tIfMQ/4HP5WXXRAIxkVMGLSHzEPuDTs6rg1UUjIvET+4CfSwWvoZIiEieRB7yZJc3sETO7M+pjTSWTm7mCbxkPeFXwIhIflajgPwzsrMBxppTOzlzBt4VdNAp4EYmTSAPezFYDbwC+HuVxppPJFWZ1oxOoi0ZE4iXqCv7PgU8CJe8gMrMbzazPzPr6+/vL3oB0Nl9ywe2iYhfNyawqeBGJj8gC3sx+Czjs7lume5+73+Tuve7e29PTU/Z2zKWC1zBJEYmTKCv4q4A3mdkzwHeBV5vZtyM83pTS2XzJ5fqKmlNJzOCkumhEJEYiC3h3/7S7r3b3dcANwH3u/q6ojleiDUEFP8NUBYmE0dKgCcdEJF5iPQ4+m3fcSy+4PVFrY5IRBbyIxEiqEgdx9/uB+ytxrInSuZkX+yhqbUypi0ZEYmXa5DOzTRMeN0167SVRNapcMtnicn2zCXh10YhIvMyUfH874fHPJ732F2VuS9mNL9c3iy6alsakhkmKSKzMFPBW4vFUzxecUwtuz66CHxge4/joWNTNEhGpiJmSz0s8nur5glOs4GfzIWs27+w8OMgNN/0i6maJiFTETB+yrjaz/01QrRcfEz5fFWnLymAuFfxrNi3j4aePsuu5IdwdswX/B4qIyLRmCvhPTHjcN+m1yc8XnExu9hX8B165ETP4wl27GM7k6GhuiLp5IiKRmjbg3f3mydvMrAs47u4LvotmLqNoAJZ1NAPQP5RRwItIzZtpmORni0MlzazJzO4DdgOHzOyaSjTwTGTGx8HPXMED9HQEI0EPD2Uia5OISKXMVNq+HXgifPwegr73HuCVwBcibFdZpMMKfqYFP4qKAd+vgBeRGJgp+cYmdMVcC3zX3fPuvpMK3QV7JsYr+Fn0wQP0tCvgRSQ+Zgr4jJm90Mx6gKuBn0x4rTW6ZpXHeAU/yz74ztYGGpJG/7ACXkRq30xV+EeAWwm6Zf6Xuz8NYGbXA49E27QzN9cK3szoaW/i8KACXkRq30yjaH4BbJpi+13AXVE1qlwyc6zgIeiHVwUvInEwbcCb2ceme93dv1ze5pRXOpcnmTBSybkF/P7j6QhbJSJSGTMl35eAdwFLgXagY9LXgpbJzrzYx2Td7U0MqIIXkRiYqQ/+UoLVmN4AbAG+A9xbCzc5QVDBz+Yu1okWtzZw4mRW0xWISM2btrx1923u/il33wx8A3gzsMPM
3lSJxp2p+VTwi1saGMsVxkfgiIjUqlmlXzhM8kXARcA+4HCUjSqXdK4w5wq+s6URgBMns1E0SUSkYmb6kPV9BHezNhMMl3ybu9dEuANksvl5VfAAx0+OsXxxcxTNEhGpiJn64L8BPAY8S3An6+sm9ku7+4LuqsnkCrMeA1/U2RoE/IlRVfAiUttmCvirp9hW/IB1wX8CmT6jCl4BLyK1baaA7wRWu/tXAczsYYK7Wh34/el+0MyagQeApvA4t7r7H55pg+cikyvQ0Ty3KXOKAa8+eBGpdTOVt58E7pjwvBHoBV4F/N4MP5sBXu3ulwCbgevM7CXza+b8pLPzGyYJMKiAF5EaN1N52+jueyc8f9DdjwBHzKxtuh8Mx8oPh08bwq+Kjp8fy819mGRHU4pkwjiuPngRqXEzpV/XxCfu/qEJT3tm2rmZJc1sG8Gwyrvd/aEp3nOjmfWZWV9/f/8smjx7Y/kCjXMMeDNjUXNKXTQiUvNmSr+HzOzfTd5oZh8AHp5p5+Hc8ZuB1cDlZvbCKd5zk7v3untvT8+MvzPmJJsv0DiHeWiKOlsb9SGriNS8mbpoPgr8vZm9E9gabruM4IPTt8z2IO5+3MzuB64Dts+9mfOTzTsN8wj4xS0NquBFpObNNF3wYeBKM3s1cGG4+Ufuft9MOw7vfs2G4d4CXAP8yZk2eC6yucK8A/746FgELRIRqZxZjSEMA33GUJ9kBXCzmSUJuoK+7+53znEfZ2QsX6AhNffh+p2tDew5MhJBi0REKieydVXd/VGC+WuqZr598F2tjRwZUQUvIrVt7ulXI/IFp+DMq4umu72RoXSOdDYfQctERCojtgGfzQfT/c4n4Hs6mgBUxYtITYttwI+NB/zc++CLAd8/pJWdRKR2xTbgs7kg4Od6oxMEy/aBAl5Ealt8Az4fzIpwJl00WptVRGpZjAN+/n3wS9tUwYtI7YttwJ9JH3xjKkFna4MCXkRqWmwDvljBz2ccPEBPe5O6aESkpsU34HPz74OH4INWVfAiUstiG/DjXTTzGEUDwQetquBFpJbFNuCzZ9AHD7B8cTMHTqQpFCq6RomISNnEPuDn2we/vruNsVyB/cdPlrNZIiIVE/uAn28f/IbuYEXCpwY0q6SI1KbYBvzYGX7IuqGnHYAP3rKVSz9/d9naJSJSKZFNF1xt410085gPHoIZJTuaUwylc5AJFvCez7QHIiLVEtvEOtMuGjMbr+IBrfAkIjVHAT+NYj88wFEFvIjUmNgG/FjuzAP+vVeu47oLlwNwVHPDi0iNiW/Ah7NJzneYJMAlazr56GtfACjgRaT2xDbgx7to5vkha1FXWwMAxxTwIlJjIgt4M1tjZj81s51m9riZfTiqY00lW4YuGggW4AY4OpI94zaJiFRSlMMkc8DH3X2rmXUAW8zsbnffEeExxxUr+FTizCr4hmSCRc0pjulDVhGpMZFV8O5+0N23ho+HgJ3AqqiON9lY3mlMJjA7s4AHWNLWqAW4RaTmVKQP3szWAS8CHqrE8SCo4Oc70dhkXW2N6oMXkZoTecCbWTvwA+Aj7j44xes3mlmfmfX19/eX7bjZfGHeUwVPtrStUaNoRKTmRBrwZtZAEO63uPttU73H3W9y91537+3p6SnbsYMKvjyn19XaqD54Eak5UY6iMeAbwE53/3JUxyllLOdnNAZ+oqXh8n3pbL4s+xMRqYQoK/irgHcDrzazbeHX9REe73nK2Qf/kg1LyOadf909UJb9iYhUQmTDJN39QaA8CTsP5eyieenGpbQ3pfjJ44d49aazyrJPEZGoxfpO1nIFfFMqydWblnHPzkO4awk/EakNsQ34sbyXbRQNwKVrOxkYHtN4eBGpGbEN+GyuQGOZ+uABzl7aCsCeI6Nl26eISJTiG/Bl7KIBWLskCPi9RxXwIlIbFPCztLpLFbyI1JbYBvxY3ssa8M0NSZYvauZZVfAiUiNiG/DZfGHeC26XsnZpK88eHSnrPkVE
ohLrgC9nBQ9BP7wqeBGpFbEN+NGxPC0NybLu88KVizg0mOF//nhXWfcrIhKF2Ab84Mksi1oayrrPd7/kbN6yeSVf/eluBoYzZd23iEi5xTLgx3IFMrkCHU3lnYkhlUzwtt41AOw8eNrMxyIiC0osA34oHayf2tFc/ql2zl+xCFDAi8jCF9OAzwHQ0VzeLhoIVndavqiZHQcU8CKysMU84KOZLPOClYvYeXAokn2LiJRLLAN+MOyiKfeHrEXnr+jg14eH+Nj3tpHNFyI5hojImYplwEfZBw/wu1eczRsvXsltj+znn58s3zqyIiLlFMuAHwy7aBZF0AcPsLKzhS/9ziV0tjZw29b9kRxDRORMxTLgo+6DB2hMJXjjxSu5e8chToxmIzuOiMh8xTTgg8BtL/M4+MnecflaMrkCn71jOx/93jYyOS3KLSILR0wDPkdbY5JUmeeimeyClYu4cuNSfrjtALc/sp+te45HejwRkbmIZcAPnsxGMgZ+Kp+49jw2r+kEYOuzxypyTBGR2Ygs4M3sm2Z22My2R3WMUobSuUj73yd60dou/v6DV3HOsnb6njlakWOKiMxGlBX8/wOui3D/JQ1lshUL+KLes7vY+uxxCgWv6HFFREqJLODd/QGgKiVtUMFXpoum6LKzuzhxMsuvD+sOVxFZGKreB29mN5pZn5n19feX56ahSnbRFF15TjcADz45wOHBdEWPLSIylaoHvLvf5O697t7b09NTjv3RP5Shu72pDK2bvVWdLWzsaeMr9zzJFV+8l13PaTIyEamuqgd8uR0ZGWM4k+Pspa0VP/bLz+1hKJPDHe7YdqDixxcRmSh2Ab/nSLAo9rqlbRU/9ps2r2RDTxvnr1jEnY8e5KGnjuCuD11FpDqiHCb5HeDnwHlmts/M3h/VsSZ6ZiBYFLsaFfyla7u47+Ov4r1Xns2zR0d5+02/4Pt9eyveDhERgMg+iXT3d0S17+nsOTJCwmB1V+UDvui3L13NkrYm/uqBp/hvP9rJq85bxlmLmqvWHhGpT7HronnmyCirulpoTFXv1FLJBK+94Cz+9K0Xk80X+Mzt29VVIyIVF7uA33NkpCr971NZ193Gx197HvfsPMQH/mYL7/vWw+w9OlrtZolInYhVwG/be5zHDwyyaXlHtZsy7v0vW88Hr97IvbsOc/+v+/mL+3dXu0kiUidiE/Duzse+t43li5v50NXnVrs54xIJ4xPXbuKJz1/HDS9eyw+27uNrP9vNb/2ff+ZXe49Xu3kiEmOxCfinBkZ4amCE//Cqc1jcWtlpCmYjlUzwgVdsoKMpxRf/cRfb9w/y+Tt38IW7dvL0wMj4HPYiIuVS2fv5I/QvvxkA4GXhlAEL0bruNn7+6dew58gI9+06zBf/cRd9e47xgy37ODY6xh9cfz7/9uUbqt1MEYmJWAX86q4W1lZh/PtcNKYSnHtWB6u7WhkYzrCxp50v3LWT7vYm/vyeJ0kljDdtXsWStsZqN1VEalwsAv7QYJoHnxzgjZesrHZTZq2lMcln3nABAG/rXcOeo6O8/isP8Ll/2MGdjx7kw9ecy/JFzazvbot8ZSoRiaeaD3h359O3PUbenRtfUZvdG4mEsb67jYf+4Bp+/PhzfPLWR3n3Nx4GIJUwLl3bxbteejYXr1rM6q4WBb6IzErNB/zgyRyHh9J88tpNbOhpr3Zzzsjilgbe1ruGxS0NtDQkeW4wze7+Ye7ZcYj/9J1HALhk9WK+/PbNbOhuw8yq3GIRWchsId1h2dvb6319fXP+uWy+QNKMRCKegVcoOPftOswzR0b40k+eIJ0tkDC4cmM3v33ZKlZ1trKys5mVi1ti+99ARKZmZlvcvXeq12q+ggdoiHmXRSJhXHPBWQC8/qIV/HTXYZ49OsqtW/bx4PcGxt/X3d7EZWd3snlNF2+4aAWrulpIKvBF6lYsKvh6lcnl2Xv0JAeOn2TfsZP8y+4Bdh0cZHd/MGVyR1OKy9Z18eJ1S7hi
/RIuWr2YplSyyq0WkXKKfQVfr5pSSc5Z1s45y4LPHt55xVoAnjw0RN+eYzy2/wQPP32U+594AoDGZIJVXS10tzdy5cZuXn5uN5es6Yz9X0Ai9UoVfB04OjLGw08fZeuzxzhw/CR7j53ksX3HKTg0pRKct7yDTcs72LR8EResXMRFqxbT1qTf/SK1YLoKXgFfp46PjvGvu4+wZc8xnnhuiJ0HBzkyMgZAwmBDTzsbe9p45QuWceHKRbzgrA5aGtW9I7LQqItGTtPZ2sj1F63g+otWAOFi5cMZHt8/yCN7j7Pr4CDb9w/y48cPAWAWLIN4wYqgyi9+X9bRpOGaIguUAl4AMDOWdTSzbFMzV29aBgSh/8yRUZ54bpBdYZX/2P4T/Oixg+M/t7StkfNXLBr/LKAY/s0NqvZFqk0BLyWZBXfYru9u47oXrhjfPpjOsuvgEI8fOMGOA4PsfG6Q7/ftZXQsD0AyYWzsaePcszo4p6d9PPzXd7cp+EUqSAEvc7aouYHL1y/h8vVLxre5O/uPn+TxA4M8vv8E2w8M8ti+E9z12EGKH/MkDNYsaWXtklZWdbawqrOFsxY1093RSE978H1pW1NVl1sUiZNIA97MrgO+AiSBr7v7/4jyeFI9ZsbqrlZWd7Vy7YXLx7ens3me6h/hN/3D7D48zO7+YfYeO8nOnYcYGB6bcl+LWxro6WhicUsD7U2pU1/NKdqaUnQ0Bd/bm1O0NCRpTCVoSBqNyUT4OHieTCRImpFMWvA9Mekr3GYGBiQsfKzPFCQmIgt4M0sCXwVeC+wDfmlmd7j7jqiOKQtPc0My6Jdfuei019LZPP1DGfqHMwwMZRgYHmNgOEP/UIaB4QyD6SzHR8fYd2yU4UyO4XSOkbAbKGpmYeCHj4NHQPjLwMaf2oTHwS+H8V8Pdmob4/uZ8HzCfuDUL5pT7z193xPbN7kNp342bMOEn60F9fyLdUlrI9//vZeWfb9RVvCXA79x96cAzOy7wJsBBbwAQfivWdLKmiWzn8O/UHBGxnKMZPIMpbOkswXG8gXGcgWy+eBrLBdsK7iTLwQ/kys4eXfy+QJ5h3yhQL4QfHcHB9yh4B4+9uc9J3zdCd8M4+879Rrj3VGOM3EEso/v9/TXT22fsK/isUrse3Ibnt/GCe2a9X/ZKquZhkajozmaKI4y4FcBeyc83wdcEeHxpA4kEkZHcwMdzQ0sX9xc7eaILGhRfpo11d9bp/2eNrMbzazPzPr6+/sjbI6ISH2JMuD3AWsmPF8NHJj8Jne/yd173b23p6cnwuaIiNSXKAP+l8C5ZrbezBqBG4A7IjyeiIhMEFkfvLvnzOxDwI8Jhkl+090fj+p4IiLyfJGOg3f3u4C7ojyGiIhMTbcMiojElAJeRCSmFPAiIjG1oBb8MLN+YM88f7wbGJjxXbVB57LwxOU8QOeyUM33XM529ynHmC+ogD8TZtZXalWTWqNzWXjich6gc1moojgXddGIiMSUAl5EJKbiFPA3VbsBZaRzWXjich6gc1moyn4usemDFxGR54tTBS8iIhMo4EVEYqrmA97MrjOzJ8zsN2b2qWq3Z67M7Bkze8zMtplZX7htiZndbWZPht+7qt3OqZjZN83ssJltn7CtZNvN7NPhdXrCzK6tTqunVuJcPmdm+8Nrs83Mrp/w2kI+lzVm9lMz22lmj5vZh8PtNXVtpjmPmrsuZtZsZg+b2a/Cc/mjcHu01yRYmqw2vwhmqdwNbAAagV8BF1S7XXM8h2eA7knb/hT4VPj4U8CfVLudJdr+CuBSYPtMbQcuCK9PE7A+vG7Jap/DDOfyOeA/T/HehX4uK4BLw8cdwK/DNtfUtZnmPGruuhAsgNQePm4AHgJeEvU1qfUKfnzdV3cfA4rrvta6NwM3h49vBt5SvaaU5u4PAEcnbS7V9jcD33X3jLs/DfyG4PotCCXOpZSFfi4H3X1r+HgI2EmwhGZNXZtpzqOUBXkeAB4YDp82hF9OxNek
1gN+qnVfp/sfYCFy4CdmtsXMbgy3neXuByH4nxxYVrXWzV2pttfqtfqQmT0aduEU/3yumXMxs3XAiwgqxpq9NpPOA2rwuphZ0sy2AYeBu9098mtS6wE/q3VfF7ir3P1S4PXAB83sFdVuUERq8Vr9JbAR2AwcBP4s3F4T52Jm7cAPgI+4++B0b51i24I5nynOoyavi7vn3X0zwfKll5vZC6d5e1nOpdYDflbrvi5k7n4g/H4YuJ3gz7BDZrYCIPx+uHotnLNSba+5a+Xuh8J/lAXgrzj1J/KCPxczayAIxVvc/bZwc81dm6nOo5avC4C7HwfuB64j4mtS6wFf0+u+mlmbmXUUHwOvA7YTnMN7wre9B/hhdVo4L6Xafgdwg5k1mdl64Fzg4Sq0b9aK//BC/4bg2sACPxczM+AbwE53//KEl2rq2pQ6j1q8LmbWY2ad4eMW4BpgF1Ffk2p/ulyGT6evJ/h0fTfwmWq3Z45t30DwSfmvgMeL7QeWAvcCT4bfl1S7rSXa/x2CP5GzBBXH+6drO/CZ8Do9Aby+2u2fxbn8DfAY8Gj4D25FjZzLywj+nH8U2BZ+XV9r12aa86i56wJcDDwStnk78Nlwe6TXRFMViIjEVK130YiISAkKeBGRmFLAi4jElAJeRCSmFPAiIjGlgBcpAzN7lZndWe12iEykgBcRiSkFvNQVM3tXOC/3NjP7WjgB1LCZ/ZmZbTWze82sJ3zvZjP7RTip1e3FSa3M7Bwzuyec23urmW0Md99uZrea2S4zuyW8E1OkahTwUjfM7Hzg7QQTvG0G8sDvAm3AVg8mffsZ8Ifhj/w18PvufjHBnZPF7bcAX3X3S4ArCe6AhWC2w48QzOW9Abgq4lMSmVaq2g0QqaDXAJcBvwyL6xaCyZ0KwPfC93wbuM3MFgOd7v6zcPvNwN+FcwetcvfbAdw9DRDu72F33xc+3wasAx6M/KxESlDASz0x4GZ3//TzNpr910nvm27+jum6XTITHufRvy+pMnXRSD25F3irmS2D8fUwzyb4d/DW8D3vBB509xPAMTN7ebj93cDPPJiPfJ+ZvSXcR5OZtVbyJERmSxWG1A1332Fm/4VgBa0EwcyRHwRGgAvNbAtwgqCfHoLpW/9vGOBPAe8Lt78b+JqZ/XG4j9+p4GmIzJpmk5S6Z2bD7t5e7XaIlJu6aEREYkoVvIhITKmCFxGJKQW8iEhMKeBFRGJKAS8iElMKeBGRmPr/YemYL3UBSusAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# validation GMSE per epoch:\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# 1-based x values to match the epoch numbering in the training log\n",
    "epoch_axis = range(1, len(epoch_val_gmse) + 1)\n",
    "fig, ax = plt.subplots()\n",
    "ax.plot(epoch_axis, epoch_val_gmse)\n",
    "ax.set(title='Validation GMSE per epoch', xlabel='epoch', ylabel='GMSE')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "03-Jun-21 15:17:55 | model saved at: saved_model/L2G_WS50_unroll20.pt\n"
     ]
    }
   ],
   "source": [
    "# save trained model\n",
    "\n",
    "import os\n",
    "\n",
    "save_path = 'saved_model/L2G_{}{}_unroll{}.pt'.format(graph_type,\n",
    "                                                      graph_size,\n",
    "                                                      num_unroll)\n",
    "\n",
    "# torch.save does not create missing parent directories\n",
    "os.makedirs(os.path.dirname(save_path), exist_ok=True)\n",
    "\n",
    "torch.save({'net_state_dict': net.state_dict(),\n",
    "            'optimiser_state_dict': optimizer.state_dict()\n",
    "            }, save_path)\n",
    "\n",
    "logging.info('model saved at: {}'.format(save_path))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Test / Inference"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 52,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "03-Jun-21 15:27:55 | GMSE: 0.05936715006828308 +- 0.0038578716206355414\n",
      "03-Jun-21 15:27:55 | aps: 0.9960138205245266 +- 0.0011615028706712285\n",
      "03-Jun-21 15:27:55 | auc: 0.9988567361111111 +- 0.0005779778762602031\n",
      "03-Jun-21 15:27:55 | layerwise test loss :[1.0435071, 0.75211203, 0.2969693, 0.29407477, 0.19469345, 0.14014131, 0.06299843, 0.07452516, 0.056948557, 0.0628472, 0.05281911, 0.058609165, 0.05863236, 0.048517235, 0.0485345, 0.048531204, 0.04853106, 0.048537016, 0.048606172, 0.066911146]\n"
     ]
    }
   ],
   "source": [
    "# Evaluate on the whole test set: collect per-batch results, then concatenate them\n",
    "net.eval()\n",
    "\n",
    "w_gt_list, adj_list, w_pred_list, loss_pred_list, layer_loss_list = [], [], [], [], []\n",
    "\n",
    "with torch.no_grad():\n",
    "    for z, w_gt_batch in test_loader:\n",
    "        z = z.to(device)\n",
    "        w_gt_batch = w_gt_batch.to(device)\n",
    "        # the last batch can be smaller than batch_size\n",
    "        this_batch_size = w_gt_batch.size()[0]\n",
    "\n",
    "        # binary ground-truth adjacency: every positive weight counts as an edge\n",
    "        adj = w_gt_batch.clone()\n",
    "        adj[adj > 0] = 1\n",
    "\n",
    "        w_list = net.validation(z, threshold=1e-04)\n",
    "        w_pred_b = torch.clamp(w_list[:, num_unroll - 1, :], min=0)\n",
    "\n",
    "        w_gt_list.append(w_gt_batch)\n",
    "        adj_list.append(adj)\n",
    "        w_pred_list.append(w_pred_b)\n",
    "        loss_pred_list.append(gmse_loss_batch(w_pred_b, w_gt_batch))\n",
    "        layer_loss_list.append(torch.stack([layerwise_gmse_loss(w_list[i, :, :], w_gt_batch[i, :])\n",
    "                                            for i in range(this_batch_size)]))\n",
    "\n",
    "# NOTE(review): assumes gmse_loss_batch returns one value per sample (1-D) and that\n",
    "# binary_metrics_batch scores any number of samples stacked along dim 0 -- TODO confirm\n",
    "w_gt_all = torch.cat(w_gt_list)\n",
    "adj_batch = torch.cat(adj_list)\n",
    "w_pred = torch.cat(w_pred_list)\n",
    "loss_pred = torch.cat(loss_pred_list)\n",
    "layer_loss_batch = torch.cat(layer_loss_list)\n",
    "loss_mean = gmse_loss_batch_mean(w_pred, w_gt_all)\n",
    "\n",
    "loss_all_data = loss_pred.detach().cpu().numpy()\n",
    "final_pred_loss, final_pred_loss_ci, _, _ = mean_confidence_interval(loss_all_data, 0.95)\n",
    "logging.info('GMSE: {} +- {}'.format(final_pred_loss, final_pred_loss_ci))\n",
    "\n",
    "aps_auc = binary_metrics_batch(adj_batch, w_pred, device)\n",
    "logging.info('aps: {} +- {}'.format(aps_auc['aps_mean'], aps_auc['aps_ci']))\n",
    "logging.info('auc: {} +- {}'.format(aps_auc['auc_mean'], aps_auc['auc_ci']))\n",
    "\n",
    "# mean and CI of the test loss at each unrolled layer (one mean_confidence_interval call per layer)\n",
    "layer_loss_np = layer_loss_batch.detach().cpu().numpy()\n",
    "layer_stats = [mean_confidence_interval(layer_loss_np[:, i], confidence=0.95) for i in range(num_unroll)]\n",
    "layer_loss_mean = [stats[0] for stats in layer_stats]\n",
    "layer_loss_mean_ci = [stats[1] for stats in layer_stats]\n",
    "logging.info('layerwise test loss :{}'.format(layer_loss_mean))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 53,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYIAAAEGCAYAAABo25JHAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAAAnJklEQVR4nO3deXRddbn/8feTeehJkzZp05EWTJNWhE6CYFtAvQroBf05ASoqKrcqirhUcDlcvfq7PweUe72CLBSu1xEHEFFBUBEKFxnaUkpLB0rHdB6SNm3m5Pn9sXfb0/RkOKfZOTk5n9daWdnTd+8nyc5+zp6er7k7IiKSvXLSHYCIiKSXEoGISJZTIhARyXJKBCIiWU6JQEQky+WlO4BkVVZW+rRp09IdhohIRlm2bNk+d69KNC/jEsG0adNYunRpusMQEckoZralt3m6NCQikuWUCEREspwSgYhIllMiEBHJckoEIiJZTolARCTLKRGIiGS5rEkE63c38fU/vkhrR1e6QxERGVayJhHUNzTzoyc2sXRzQ7pDEREZVrImEbzm9LEU5Oaw5KW96Q5FRGRYyZpEUFKQx/xpFSxZr0QgIhIvaxIBwKIZVazd1cTuQ63pDkVEZNjIqkSwsKYSQGcFIiJxsioRzKwuo3JUIUte2pfuUEREho3IEoGZ3WVme8xsVS/zzcy+Z2YbzGylmc2NKpajcnKMRTWVPPHSXrq6PerNiYhkhCjPCH4MXNzH/EuAmvDrWuAHEcZyzKIZVTQ0d7B6x8Gh2JyIyLAXWSJw9yXAgT4WuRz4iQeeAsrNbEJU8Ry1QPcJREROkM57BJOAbXHj9eG0k5jZtWa21MyW7t17agfwylGFnDmpjCXrdZ9ARATSmwgswbSEF+7d/Q53n+/u86uqEna5mZSFNVUs39pAU2vHKa9LRCTTpTMR1ANT4sYnAzuGYsOLaqro7HaefHn/UGxORGRYS2ciuB+4Onx66DXAQXffORQbnndaBaUFubpPICIC5EW1YjP7JXAhUGlm9cC/AvkA7n478ABwKbABaAY+GFUsPRXk5XDeGWN5XO8TiIhElwjc/cp+5jvw8ai2359FM6r465o9bN53hGmVpekKQ0Qk7bLqzeJ4i2qCm86qRioi2S5rE8FpY0uYMqZY9wlEJOtlbSIwMxbVVPGPl/fT3tmd7nBERNImaxMBBPcJjrR3sWyLei0TkeyV1Yng/DPGkpdjPK77BCKSxbI6EcSK8pk7tUI3jEUkq2V1IgBYNKOSVdsPse9wW7pDERFJi6xPBAvDx0if0MtlIpKlsj4RnDlpNBUl+XqMVESyVtYngtwcY0FNFUte2ke3ei0TkSyU9YkAYFFNJfsOt7F2V1O6QxERGXJKBATvE4DKTYhIdlIiAMaXFVFXHdN9AhHJSkoEoYU1lSzd3EBze2e6QxERGVJKBKFFM6po7+rmqY3qtUxEsosSQejV08ZQlJ+jTu1FJOsoEYSK8nM5d/pY3TAWkayjRBBn0YwqNu49Qn1Dc7pDEREZMkoEcRbVVALo8pCIZBUlgjivGDeKCaOL9BipiGQVJYI4R3st+9+X99HZpV7LRCQ7KBH0sGhGFU2tnazY1pjuUEREhoQSQQ8LXlFJjsESlaUWkSyhRNDD6JJ8zp5SrvsEIpI1lAgSWFhTxcr6Rhqb29MdiohI5JQIErhgRiXdDk9s0OUhERn5lAgSOHtyObGiPF0eEpGsoESQQF5uDgteUcmS9ftwV69lIjKyKRH0YtGMKnYdamXDnsPpDkVEJFJKBL1YGJabeEyXh0RkhIs0EZjZxWa2zsw2mNlNCeaPNrM/mNnzZrbazD4YZTzJmFxRwulVpXqfQERGvMgSgZnlArcClwCzgCvNbFaPxT4OvOjuZwMXAt8xs4KoYkrWopoqnt64n9aOrnSHIiISmSjPCM4BNrj7RndvB+4GLu+xjAMxMzNgFHAAGDZ9RV4wo4q2zm6e
2XQg3aGIiEQmykQwCdgWN14fTov3fWAmsAN4Abje3YdNtbdzTx9DQW6OHiMVkREtykRgCab1fBbzTcAKYCIwG/i+mZWdtCKza81sqZkt3bt36A7KJQV5vHp6hXotE5ERLcpEUA9MiRufTPDJP94HgXs9sAHYBNT1XJG73+Hu8919flVVVWQBJ7Kopor1uw+z62DrkG5XRGSoRJkIngVqzGx6eAP4CuD+HstsBV4PYGbjgVpgY4QxJW1hTZB4dFYgIiNVZInA3TuB64CHgDXAr919tZktNrPF4WJfA843sxeAvwE3uvuwel5z5oQYVbFC3ScQkRErL8qVu/sDwAM9pt0eN7wDeGOUMZwqM2NhTSWPrN1DV7eTm5Po1oeISObSm8UDcMGMKhqbO3hh+8F0hyIiMuiUCAZgwSsqMUOXh0RkRFIiGICxowo5c+JoHtcNYxEZgZQIBmhhTSXLtzZyqLUj3aGIiAwqJYIBOu+MsXR1Oyu36T6BiIwsSgQDNHNC8MLz2l2H0hyJiMjgUiIYoMpRhVSOKmDdrqZ0hyIiMqiUCJJQWx1j3W4lAhEZWZQIklBXXca6XU10dasfYxEZOZQIklBbHaOts5st+4+kOxQRkUGjRJCEuuoYgO4TiMiIokSQhJpxMXIM1igRiMgIokSQhOKCXKaNLWWdHiEVkRFEiSBJtdUx1uqMQERGECWCJNVVl7H1QDPN7Z3pDkVEZFAoESSptjqGO6zffTjdoYiIDAolgiTNnBA8ObR2p+4TiMjIoESQpCkVJZQU5Oo+gYiMGEoEScrJMWrGx/QugYiMGEoEKZhZHWPtrkO4q9SEiGQ+JYIU1FbHaGjuYG9TW7pDERE5ZUoEKagNS03oPoGIjARKBCmoq1YnNSIycigRpGBMaQHjYoU6IxCREUGJIEW11XpySERGBiWCFM2cUMZLew7T2dWd7lBERE6JEkGKasfHaO/sZrM6qRGRDKdEkKK6CXpySERGBiWCFL1i3Chyc4y1O5UIRCSzKRGkqDAvl+mVpTojEJGMp0RwCurCUhMiIpks0kRgZheb2Toz22BmN/WyzIVmtsLMVpvZY1HGM9jqqmPUN7RwuE2d1IhI5oosEZhZLnArcAkwC7jSzGb1WKYcuA24zN1fCbwzqniiUBu+Yaz3CUQkk/WZCMysLm64sMe81/Sz7nOADe6+0d3bgbuBy3sscxVwr7tvBXD3PQMNfDioO1ZzSJeHRCRz9XdG8Iu44X/0mHdbP20nAdvixuvDafFmABVm9qiZLTOzqxOtyMyuNbOlZrZ07969/Wx26EyuKGZUYZ7OCEQko/WXCKyX4UTjfbU9qmcB/zxgHvBm4E3Al8xsxkmN3O9w9/nuPr+qqqqfzQ4dM2PG+FF6ckhEMlp/icB7GU403lM9MCVufDKwI8Eyf3b3I+6+D1gCnN3PeoeVugllrN2pTmpEJHPl9TN/spl9j+DT/dFhwvGel3l6ehaoMbPpwHbgCoJ7AvF+D3zfzPKAAuBc4JYk4k+7uuoYv3i6k12HWpkwujjd4YiIJK2/RPDZuOGlPeb1HD+Bu3ea2XXAQ0AucJe7rzazxeH82919jZn9GVgJdAM/cvdVSf0EaVY7/nipCSUCEclEfSYCd/+fntPMrAJo9AFcC3H3B4AHeky7vcf4t4FvDyjaYehYJzU7m7iodlyaoxERSV5/j49++egjpGZWaGaPAC8Du83sDUMR4HA3uiSfCaOLWKdHSEUkQ/V3s/jdwLpw+P0E9waqgAuAf48wrowSlJrQk0Mikpn6SwTtcZeA3gTc7e5d7r6G/u8vZI3a6jJe3nuYDnVSIyIZqL9E0GZmZ5pZFXAR8HDcvJLowsosddUxOrqcjXvVSY2IZJ7+EsGngN8Ca4Fb3H0TgJldCjwXbWiZ43gnNbpPICKZp7+nhp4C6hJMP+lpoGx2euUo8nKMtbuaTiqmJCIy3PWZCMzs033Nd/fvDm44makgL4czqkap5pCIZKT+
bvjeDKwAHgTa6L++UNaqmxDj2U0H0h2GiEjS+ksEcwlKQ7wZWAb8EvjbQF4myza11TF+v2IHB1s6GF2cn+5wREQGrM+bxe6+wt1vcvfZwJ0E/Qm8aGaXDUVwmeRo3wTrd+vykIhklgH1UBY+PjoHeBVBxdCM6kBmKBwvNaEnh0Qks/R3s/iDBG8XFxE8RvquTOtFbKhMGF1ErChPbxiLSMbp7x7BncALwFaCN4vfaHb8frG76xJRyMyYWV2mJ4dEJOP0lwguSjDt6I1iPUHUQ211jPue2467E58wRUSGs/4SQTkw2d1vBTCzZwiKzjlwY7ShZZ7a6hhNbZ1sb2xhcoUqcIhIZujvZvHngPvjxguA+cCFwOKIYspYM8NSE7o8JCKZpL9EUODu2+LGn3D3/e6+FSiNMK6MNCOutzIRkUzRXyKoiB9x9+viRqsGP5zMFivKZ1J5sRKBiGSU/hLB02b2kZ4TzexfgGeiCSmzzZwQ07sEIpJR+rtZfANwn5ldBSwPp80DCoG3RhhXxqqtjvH3dXtp6+yiMC833eGIiPSrvzLUe4Dzzex1wCvDyX9y90cijyxD1VaX0dXtvLznCLMmlqU7HBGRfg2ou8nwwK+D/wDMrD7eSY0SgYhkggHVGpKBm1ZZSkFujh4hFZGMoUQwyPJzczhj3Cg9OSQiGUOJIAIzq2Pqv1hEMoYSQQRqq2PsPtRGY3N7ukMREemXEkEE6iaEfRPo8pCIZAAlgggc7a1ML5aJSCZQIojAuFgh5SX5rFO3lSKSAZQIImBm1FXHdGlIRDKCEkFE6sLeyrq7vf+FRUTSKNJEYGYXm9k6M9tgZjf1sdyrzazLzN4RZTxDqbY6RnN7F/UNLekORUSkT5ElAjPLBW4FLgFmAVea2axelvsm8FBUsaRDXVypCRGR4SzKM4JzgA3uvtHd24G7gcsTLPcJ4B5gT4SxDDl1UiMimSLKRDAJiO/drD6cdoyZTQLeBtze14rM7FozW2pmS/fu3TvogUahtDCPqWNKVHNIRIa9KBOBJZjW887pfwA3untXXyty9zvcfb67z6+qypyO0eqqY6zRpSERGeYGVIY6RfXAlLjxycCOHsvMB+42M4BK4FIz63T3+yKMa8jUVcf465rdtHZ0UZSvTmpEZHiK8ozgWaDGzKabWQFwBXB//ALuPt3dp7n7NOC3wMdGShKAoNREt8OGPYfTHYqISK8iSwTu3glcR/A00Brg1+6+2swWm9niqLY7nNSGTw6tUakJERnGorw0hLs/ADzQY1rCG8Pu/oEoY0mHaWNLKcxTJzUiMrzpzeII5eYYM8bHVHNIRIY1JYKI1VbHWLNTiUBEhi8lgojVVcfYd7iN/Yfb0h2KiEhCSgQRq6sOOqnRfQIRGa6UCCJ27MkhJQIRGaaUCCJWFStkbGkB6/SGsYgMU0oEQ6BuQkyXhkRk2FIiGAK148tYt7uJLnVSIyLDkBLBEKibEKO1o5utB5rTHYqIyEmUCIbAsU5qVGpCRIYhJYIhUDMuhpk6qRGR4UmJYAgUF+QyfWypbhiLyLCkRDBEaqtj6r9YRIYlJYIhUlsdY8uBZl7ccYhtB5rZ29TG4bZOOru60x2aiGS5SMtQy3GvmjQad7j0e4+fNC8vxyjKz6UoPyf8Hg7n5Z4wXhUr5PrX11BeUpCGn0BERiolgiFyYe04fnLNOTS2dNDa0UVbRxctHV20dnTTGn5vCae3dh6ffqS9k/1H2mnr6GLrgWY27DnMf3/g1eTl6mRORAaHEsEQyc0xFs2oOqV1/PrZbXzunpV848G1fPEtswYpMhHJdkoEGeRdr57CizsP8aMnNjFzQhlvnzc53SGJyAig6wsZ5gtvnsl5p4/l8797gee3NaY7HBEZAZQIMkx+bg63vmcuVaMKufanS9lzqDXdIYlIhlMiyEBjSgv44dXzOdTSyeKfLaOtsyvdIYlIBlMiyFCzJpZx8zvPZvnWRr5832rcVdlURFKjRJDB
3nzWBK676BX8auk2fvrUlnSHIyIZSokgw336n2bwhpnj+OofXuQfL+9PdzgikoGUCDJcTo5xy7tnM72ylI/9fBnb1OeBiCRJiWAEiBXl88Or59PV7Vz702U0t3emOyQRySBKBCPE9MpSvnflHNbtOsRnf7NSN49FZMCUCEaQC2vHcePFdfzphZ3c9ujL6Q5HRDKEEsEIc+2i07l89kRufngdf1uzO93hiEgGUCIYYcyMb779LF45sYzr717Bhj3qFU1E+qZEMAIV5edyx/vmU5Sfw0d+soyDLR3pDklEhrFIE4GZXWxm68xsg5ndlGD+e8xsZfj1pJmdHWU82WRieTE/eO886hua+eQvn6OrWzePRSSxyBKBmeUCtwKXALOAK82sZxH9TcAF7n4W8DXgjqjiyUavnjaGr152Jo+t38u3H1qX7nBEZJiKsj+Cc4AN7r4RwMzuBi4HXjy6gLs/Gbf8U4AK7A+yq86dyos7D3L7Yy8zc0KMy2dPSndIIjLMRHlpaBKwLW68PpzWmw8BDyaaYWbXmtlSM1u6d+/eQQwxO3z5La/knGlj+NxvV/L7Fdv1joGInCDKRGAJpiU8ApnZRQSJ4MZE8939Dnef7+7zq6pOrbvHbFSQl8Nt75177Emij/5sOfsOt6U7LBEZJqJMBPXAlLjxycCOnguZ2VnAj4DL3V1V0yJSOaqQ3yw+n5suqeORtXt40y1LePCFnekOS0SGgSgTwbNAjZlNN7MC4Arg/vgFzGwqcC/wPndfH2EsAuTmGIsvOIM/fnIBE8uL+ejPl3P93c/R2Nye7tBEJI0iSwTu3glcBzwErAF+7e6rzWyxmS0OF/syMBa4zcxWmNnSqOKR42aMj3Hvx87nhjfM4E8rd/LGW5bwyFq9hSySrSzTbhzOnz/fly5Vvhgsq7Yf5DO/eZ61u5p457zJfOmfZ1FWlJ/usERkkJnZMnefn2ie3izOcmdOGs3vr3stH7/oDO5ZXs/FtyzhiZf2pTssERlCSgRCYV4un31THfd89HyKCnJ5751P88X7XuBIm/o1EMkGSgRyzJypFTzwyYV8eMF0fv70Vi75z8d5ZtOBdIclIhFTIpATFOXn8sW3zOJX154HwLvv+Adf++OLtHZ0pTkyEYmKEoEkdM70MTx4/ULee+5p3PnEJi793uM8t7Uh3WGJSASUCKRXpYV5fO2tZ/KzD51LW0c3b//Bk/zq2a3pDktEBpkSgfRrQU0lf/7UQhbWVHHjPS/w4//dlO6QRGQQKRHIgMSK8rnj6nm86ZXj+cofXuS2RzekOyQRGSRKBDJghXm5fP+quVw+eyLf+vM6vvPwOlUyFRkBouyPQEag/Nwcvvuu2RTn5/Jfj2ygub2LL755JmaJis2KSCZQIpCk5eYY//62V1GUn8udT2yipaOLr19+Jjk5SgYimUiJQFKSk2P86z/PoqQgl9sefZnW9i6+9Y6zyMvV1UaRTKNEICkzMz53cR0lBbnc/PB6Wjq6+M8r5lCQp2Qgkkn0Hyun7LrX1fClt8ziwVW7+JefLtVbyCIZRolABsWHFkzn39/2Kh5dv5drfvysCtaJZBAlAhk0V507le++62ye2rifq+96hkOtHekOSUQGQIlABtXb5kzm1qvmsrK+kat++BQHjqgbTJHhTolABt0lr5rAHe+bz/rdh7nijn+wp6k13SGJSB/UVaVE5skN+/jwT5YyvqyIn3/4XCaWF/fbpqOrm/qGFjbvO8Lm/UfYvO8Im/Y3s2X/EXYebOX0ylLmTK1gztRy5k4t5/TKUXp/QWQA+uqqUolAIrVsywE+cNezlBXn84uPnMtpY0sTHuw3729m8/4j1De00NV9fJ8cVZjHtMoSpo0tZVysiJf2NLFiWyNNrcHN6FhRHrOnlDNnSjlzplYwe0o5FaUF6fpxRQZdW2cXq3ccYvmWBmZNLOP8MypTWo8SgaTVqu0Hed+dTwNQVpzf58F+2thSplWWMm1sCdMqSxlb
WnBS+YrubmfjvsMs39rIim2NPLe1kXW7DnF0ldMrS5kzpZzZU8uZM6WCugkx8vt50a29s5vGlnYamztoONJOQ3MHB1uC7w3N7Rxs7qCptZOxowqYXFHM5IqSY98rSvJVYkMGzd6mNpZvbWD5lgaWbWlg5faDtHd2A/Avi07n85fOTGm9SgSSdut2NfGNB9dQWpg3oIN9so60dbKy/mCYGBp4blsje5vaACjMy+GsyaN55cTRdHU7Dc3BAb+xpZ2GIx00NrdzpL33dx/yc43ykgJihXnsbWqjqcejsSUFuUyuKGZS+YkJIvhezJhefj53p6Wji4bmIIbG5o5jcTWG0xrCaQdb2jnY0kFnd+r/rwaMixWdFN/kMSVUlxWRq0tsQ66r23lpTxPLwoP+si0NbNnfDEBBbg5nTipj3mkVzDutgrlTKxhXVpTytpQIJOu4O9sbW46dMTy3tYE1O5soys+hvKSA8pJ8KkoKKC/Op7ykgIqSfMpLjg4XhMPBMiUFuSccyA+2dFDf0Ex9QwvbG1qob2g5Nl7f0Myh1hMTRXF+LpMqipkwuoi2juDMo6G5g4PNHbR3dff6MxTn5x6Lqbw4n9HF+eTlpn6w7nZnz6E2tjU0s/tQ2wnz8nKMieXFx5NDj4Q2XoliUDS1dvD8toMs3XKAZVsaWLG18dgHi8pRBcydWsH8acGB/5UTR1OUnzto21YiEBlCh1o7TkoQ2xta2HmoleL8HMqLC44f4EvyqSjJZ3RxXHIqCQ76g3kQ6Kmts4sdja0nJLD6MOZtB5rZ03RiosjPNapHF1GYF11Mw13PY+VJR07vc5TO7m62N7TQ7WAGteNjxz7tzzutgqljSiK9xNhXIlCtIZFBVlaUT9mEfGZOKEt3KL0qzMtlemUp0ytLE85v7ehiR2PLseRQ39DM9sYWOrsy64PjoLM+R086kPec//a5k5l3WvBQQ6wof9DDS5USgYicpCg/l9OrRnF61ah0hyJDQC+UiYhkOSUCEZEsp0QgIpLllAhERLJcpInAzC42s3VmtsHMbkow38zse+H8lWY2N8p4RETkZJElAjPLBW4FLgFmAVea2awei10C1IRf1wI/iCoeERFJLMozgnOADe6+0d3bgbuBy3sscznwEw88BZSb2YQIYxIRkR6iTASTgG1x4/XhtGSXERGRCEX5Qlmid6V7vpY4kGUws2sJLh0BHDazdSnGVAnsS7Gt2g+PGNRe7dU+Naf1NiPKRFAPTIkbnwzsSGEZ3P0O4I5TDcjMlvZWa0PtMyMGtVd7tT+1/+FEorw09CxQY2bTzawAuAK4v8cy9wNXh08PvQY46O47I4xJRER6iOyMwN07zew64CEgF7jL3Veb2eJw/u3AA8ClwAagGfhgVPGIiEhikRadc/cHCA728dNujxt24ONRxtDDqV5eyvb2wyEGtVd7tR9kGdcfgYiIDC6VmBARyXJKBCIiWS4rEoGZ3WVme8xsVYrtp5jZ381sjZmtNrPrk2xfZGbPmNnzYfuvphhHrpk9Z2Z/TKHtZjN7wcxWmFnSfX2aWbmZ/dbM1oa/h/OSaFsbbvfo1yEz+1SS278h/N2tMrNfmllSvXib2fVh29UD2XaifcbMxpjZX8zspfB7RZLt3xluv9vM+nwEsJf23w5//yvN7HdmVp5k+6+FbVeY2cNmNjGZ9nHzPmNmbmaVSW7/K2a2PW4/uDTZ7ZvZJyyoX7bazL6V5PZ/FbftzWa2Isn2s83sqaP/Q2Z2Tm/t+1jH2Wb2j/B/8Q9mlrAbu96OOcnsg0lx9xH/BSwC5gKrUmw/AZgbDseA9cCsJNobMCoczgeeBl6TQhyfBn4B/DGFtpuBylP4Hf4P8OFwuAAoT3E9ucAu4LQk2kwCNgHF4fivgQ8k0f5MYBVQQvCAxF+BmmT3GeBbwE3h8E3AN5NsPxOoBR4F5qew/TcCeeHwN1PYflnc8CeB25NpH06fQvAk4Ja+9qdetv8V4DMD/Jslan9R
+LcrDMfHJRt/3PzvAF9OcvsPA5eEw5cCj6bwMzwLXBAOXwN8rZe2CY85yeyDyXxlxRmBuy8BDpxC+53uvjwcbgLWkEQpDA8cDkfzw6+k7tKb2WTgzcCPkmk3GMJPLYuAOwHcvd3dG1Nc3euBl919S5Lt8oBiM8sjOKCf9OJhH2YCT7l7s7t3Ao8Bb+urQS/7zOUECZHw+1uTae/ua9x9QG/F99L+4TB+gKcIXsBMpv2huNFS+tgH+/ifuQX4XF9t+2k/IL20/yjwDXdvC5fZk8r2zcyAdwG/TLK9A0c/wY+mn32wl3XUAkvC4b8Ab++lbW/HnAHvg8nIikQwmMxsGjCH4FN9Mu1yw1PRPcBf3D2p9sB/EPwDdifZ7igHHjazZRaU7EjG6cBe4L/DS1M/MrPEvZ737wr6+AdMxN23AzcDW4GdBC8ePpzEKlYBi8xsrJmVEHyam9JPm0TGe/jCY/h9XArrGCzXAA8m28jM/q+ZbQPeA3w5ybaXAdvd/flktxvnuvDy1F0pXNaYASw0s6fN7DEze3WKMSwEdrv7S0m2+xTw7fD3dzPw+RS2vQq4LBx+JwPYD3sccyLZB5UIkmBmo4B7gE/1+HTVL3fvcvfZBJ/izjGzM5PY7luAPe6+LJlt9vBad59LUPr742a2KIm2eQSnuD9w9znAEYLT0qRY8Ib5ZcBvkmxXQfBJaDowESg1s/cOtL27ryG4lPIX4M/A80Bnn42GMTP7AkH8P0+2rbt/wd2nhG2vS2KbJcAXSDJ59PAD4AxgNkFC/06S7fOACuA1wGeBX4ef7pN1JUl+GAl9FLgh/P3dQHiGnKRrCP7/lhFc8mnva+FTOeYkQ4lggMwsn+AP8nN3vzfV9YSXVB4FLk6i2WuBy8xsM0E579eZ2c+S3O6O8Pse4HcEZcIHqh6ojzuL+S1BYkjWJcByd9+dZLs3AJvcfa+7dwD3AucnswJ3v9Pd57r7IoLT9WQ/DQLstrBMevi910sTUTGz9wNvAd7j4YXiFP2CXi5L9OIMgkT8fLgfTgaWm1n1QFfg7rvDD0TdwA9Jbh+EYD+8N7zU+gzB2XGvN6wTCS8t/h/gV0luG+D9BPseBB9mko0fd1/r7m9093kEyejlPmJNdMyJZB9UIhiA8FPHncAad/9uCu2rjj7hYWbFBAe2tQNt7+6fd/fJ7j6N4NLKI+4+4E/EZlZqZrGjwwQ3HQf8BJW77wK2mVltOOn1wIsDbR8n1U9iW4HXmFlJ+Ld4PcE10wEzs3Hh96kEB4JU4rif4GBA+P33KawjZWZ2MXAjcJm7N6fQviZu9DKS2wdfcPdx7j4t3A/rCW5m7kpi+/F9jbyNJPbB0H3A68J1zSB4aCHZSpxvANa6e32S7SC4J3BBOPw6UvgwEbcf5gBfBG7vZbnejjnR7IODccd5uH8R/NPvBDoIduAPJdl+AcE19pXAivDr0iTanwU8F7ZfRR9PKwxgXReS5FNDBNf4nw+/VgNfSGG7s4Gl4c9wH1CRZPsSYD8wOsWf+6sEB65VwE8JnxxJov3jBMnreeD1qewzwFjgbwQHgL8BY5Js/7ZwuA3YDTyUZPsNBP13HN0H+3rqJ1H7e8Lf30rgD8CkVP9n6OcptF62/1PghXD79wMTkmxfAPws/BmWA69LNn7gx8DiFP/+C4Bl4T70NDAvhXVcT/AE0HrgG4TVHRK0TXjMSWYfTOZLJSZERLKcLg2JiGQ5JQIRkSynRCAikuWUCEREspwSgYhIllMikGHHzB61fqpzDtJ2PhlWd0z6Dd1T2OaPzewd4XC/P6eZnWdmPzSz+Wb2vXDahWaW1At1/WxjmpldFTd+bFuSHSLtqlJkqJlZnh8vzNafjxFUk9yUxhj6czHwZ3dfSvAeBwTvkhwGnhykmKYBVxG8bUyPbUkW0BmBpCT8FLkm/LS62oL69sXhvGOfdM2sMixJgJl9wMzuC+uwbzKz68zs02Ehu6fMbEzcJt5rZk9a0IfAOWH7
0rBY2bNhm8vj1vsbM/sDQangnrF+OlzPKgv7IjCz2wletLvfzG7osfwHzOz7ceN/NLMLw+HDYeG258OYx4fTf2xm3zWzvwPftOO164/2HdBX3wW5YftVFtSpj4/n9cBfw7OAP1pQgGwxcIMFdfEXhm+u3xP+Xp41s9eG6/2Kmd1hZg8DPwn/Zo+b2fLw6+hZxTcIirmtsKDfhwst7PPCgvr394U/x1Nmdlbcuu8K/9YbzeyTcX+jP4W/n1Vm9u7efm4ZPnRGIKeiBrjS3T9iZr8mqF3TXw2kMwkqKRYRvCl7o7vPMbNbgKsJqqwClLr7+RYUx7srbPcFgvIa11hQsuMZM/truPx5wFnufkLZXzObB3wQOJegX4inzewxd19sQcmGi9w9mTIFpQQlrb9gQccoHwG+Hs6bAbzB3bvMbCXwCXd/zMz+DfhXguqVicwmeMv3zDDm8vB7JdDh7gctrK3m7pvDJHbY3W8Ol/sFcIu7P2FBCY2HCEpvA8wDFrh7iwWF4/7J3VstKDfxS2A+QQHBz7j7W8L1XRgX21eB59z9rWb2OuAnYbwAdQR9BMSAdWb2A4IzmB3u/uZwXaMH+HuVNFIikFOxyd1XhMPLCC4x9OfvHtRXbzKzgwSlDiAoPXBW3HK/hKCmu5mVhQfHNxIU3/tMuEwRMDUc/kvPJBBaAPzO3Y8AmNm9BGWInxtArIm0A0d7iFsG/FPcvN+ESWA0Qcc9j4XT/4e+K65uBE43s/8C/sTxs5o3kuAMJ4E3ALPseCHOMgtrSwH3u3tLOJwPfN/MZgNdBImrPwsIi9O5+yMWlPI+enD/kwd9A7SZ2R5gPMHf8WYz+yZBKZTHB7ANSTMlAjkVbXHDXUBxONzJ8cuOPbuUjG/THTfezYn7Y8/aJ07wif7t3qNzFzM7l6A0diKplCmOjx9O/Bk6/Hhdli5OjLm3GPrk7g1mdjbwJuDjBJ2mXENQrXUgRQ5zgPPiDvgAhIkhPqYbCGocnR22aR3AuhP9/o7+/D3//nnuvj48C7sU+H9m9rC7/9sAtiNppHsEEoXNBJckAN6R4jreDWBmCwg6ojlIcMnjExYe4cxszgDWswR4qwWVS0sJCr/19yl1MzDbzHLMbApJlhsOY20ws4XhpPcR9IqWUHgJKMfd7wG+BMwNf8azCIqN9dREcDnmqIeJ61sg/MSfyGhgpwdloN9H0G1oovXFW0LQic3RS0b7vI+6+Bb0g9zs7j8j6LwllXLlMsR0RiBRuJmg05D3AY+kuI4GM3uSoGvAa8JpXyO4h7AyPFBuJqjN3yt3X25mPwaeCSf9yN37uyz0vwR9JL/A8UqXyXo/cHt4XX4jwX2K3kwi6P3t6AezzxMk0ufizj7i/QH4rQU3yz9B0P/wreF9iTyCg/fiBO1uA+4xs3cCf+f42cJKoNPMnieozhn/+/lKGNtKoJnjJZB78yqCXry6CapufrSf5WUYUPVRkWHIzL4IbHD3u9Mdi4x8SgQiIllO9whERLKcEoGISJZTIhARyXJKBCIiWU6JQEQkyykRiIhkuf8PAj99JLCvj/UAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# Layer-wise GMSE: one value per unrolled iteration. Derive the x-axis from the\n",
    "# data instead of hard-coding 20, so the plot stays correct if num_unroll changes.\n",
    "unroll_idx = np.arange(1, len(layer_loss_mean) + 1)\n",
    "fig, ax = plt.subplots()\n",
    "ax.plot(unroll_idx, layer_loss_mean)\n",
    "ax.set_xticks(unroll_idx)\n",
    "ax.set_ylabel('GMSE')\n",
    "ax.set_xlabel('number of unrolls/iterations')\n",
    "ax.set_title('layer-wise GMSE')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 54,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "03-Jun-21 15:27:59 | results saved at: saved_results/L2G_WS50_unroll20.pt\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "\n",
    "# Bundle training curves and test-set metrics into one dict for later comparison.\n",
    "result = {\n",
    "    'epoch_train_gmse': epoch_train_gmse,\n",
    "    # validation curve -- must not alias the training curve\n",
    "    'epoch_val_gmse': epoch_val_gmse,\n",
    "    'pred_gmse_mean': final_pred_loss,\n",
    "    'pred_gmse_mean_ci': final_pred_loss_ci,\n",
    "    'auc_mean': aps_auc['auc_mean'],\n",
    "    'auc_ci': aps_auc['auc_ci'],\n",
    "    'aps_mean': aps_auc['aps_mean'],\n",
    "    'aps_ci': aps_auc['aps_ci'],\n",
    "    'layerwise_gmse_mean': layer_loss_mean,\n",
    "    # key previously carried a stray trailing space ('layerwise_gmse_mean_ci ')\n",
    "    'layerwise_gmse_mean_ci': layer_loss_mean_ci\n",
    "}\n",
    "\n",
    "result_path = 'saved_results/L2G_{}{}_unroll{}.pt'.format(graph_type,\n",
    "                                                          graph_size,\n",
    "                                                          num_unroll)\n",
    "# Make sure the output directory exists so a fresh checkout does not fail on open().\n",
    "os.makedirs(os.path.dirname(result_path), exist_ok=True)\n",
    "\n",
    "# NOTE: despite the .pt extension this is a plain pickle (protocol 4), not torch.save;\n",
    "# load it back with pickle.load.\n",
    "with open(result_path, 'wb') as handle:\n",
    "    pickle.dump(result, handle, protocol=4)\n",
    "\n",
    "logging.info('results saved at: {}'.format(result_path))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 62,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAWQAAAELCAYAAADuufyvAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAAA4wElEQVR4nO2deZgcVdX/P2dmkkCAhE0WCQIqCPgqAWLCK/wEDEJYBFSUHVk0gqwiCAqIgCDIjixJhICsQUQQ2UJk81UWEyACYVFAkBiQVQIEkszM+f1RNdjpe2rq1nTXTHdzPs9Tz/ScultV3z5dfe693yuqiuM4jjPwtA10AxzHcZwEd8iO4zgNgjtkx3GcBsEdsuM4ToPgDtlxHKdBcIfsOI7TILhDdkpFRJ4XkS3S1z8SkYv7WM4sEdmsnm1znEajY6Ab4Hx4UNVTYtKJyGXAbFU9tiLvp8tql+M0Cv6E7EQjIv4F7jgl4g7Z6Qkr/FBEnhCRN0XkUhFZTEQ2E5HZInKUiLwMXCoibSJytIg8KyKvi8ivRWTZirL2FJEX0nPHVNXzExG5suL/TUTkPhH5j4i8KCJ7i8h4YHfgByLyjoj8vqKNPaGPISJyjojMSY9zRGRIeq6nzd8XkVdE5CUR2acfbqPj1Iw7ZKeH3YGtgE8AawE94YKVgGWB1YDxwCHAjsCmwEeBN4ELAERkXeAiYM/03HLACKsyEfkYcBvwC+AjwEhgpqpOAq4Cfq6qS6rql43sxwAbpXnWA0ZXtLenzcOBVYD9gAtEZJn4W+E4A4M7ZKeH81X1RVV9AzgZ2DW1dwPHq+p8VX0P+A5wjKrOVtX5wE+AndJwxk7Azar6x/TccWl+i92BP6jqNaq6UFVfV9WZkW3dHThRVV9R1VeBE0i+BHpYmJ5fqKq3Au8An4os23EGDI8JOj28WPH6BZInXIBXVfX9inOrATeISKWj7QJWTPN8UI6qvisir2fUtyrwbB/b+tG0jVZ7AV5X1c6K/+cBS/axLsfpN/wJ2elh1YrXHwPmpK+r5QBfBLZW1aUrjsVU9V/AS5XliMhQkrCFxYsk4RGLPAnCOSRfDFZ7HadpcYfs9HCgiIxIB+h+BFybkW4CcLKIrAYgIh8RkR3Sc78BtksH6wYDJ5Ldx64CthCRb4hIh4gsJyIj03P/Bj7eS1uvAY5N614e+DFwZS/pHacpcIfs9HA1cAfwXHr8NCPducBNwB0i8jbwADAGQFVnAQemZb1EMuA32ypEVf8JbAN8H3gDmEkyQAdwCbBuOvviRiP7T4EZwKPAY8DDvbTXcZoGcYF6R0SeB76lqn8Y6LY4zocZf0J2HMdpENwhO47jNAgesnAcx2kQ/AnZcRynQXCH7DiO0yDkrtQTkbWBHUh0AZRkAv5NqvpkZB1BTGSdYcOCRLttuaWZ+djrfh0WqF2Bra1tUGRziqEarvwVKed7rLt7fmBraxtipMwKM0ld21Mcq13ltMnqAyLtpdRVO+F96e7uDGxtbUUWztb/vi66uLGipmiRP7tffmWDDQPbkEH253XKgw/WfGGriUTHYV9QHegPzSL0eqdF5CgSTYMpwF9S8wjgGhGZoqqnltw+x3GcQgwaPHigm9Bn8r769gM+raoLK40ichYwCzAdciqhOB5g4sSJjB8/vg5NdRzHyWfQEOtXZXOQ55C7CYVcAFYmW8WLVEJxEsA6w4bp2Uccscj5J+fONfP9aPvtA9uCBa8Gto6OpQJbV9dC8+d9sZ+x4S+dqu+inhZk5A9//RQJb9g/Y63OJXR3LzDShk8Gs26bZNb16a3jviStkA3Y12W3yf5w2D+Pw/uXVb+Vv7198ZryW/cPoLs77ANWvxKxf/1a+a36u7vjQ1FlhOisdgK0t8eGLOx+ef2MBwPbHp//f0WaVohWfkI+DLhTRP7Of1W8PgZ8Ejiong2xnHERsj74rYjV6YsQ64ydbCflhNTaL+vFYkOHDnQT
+kyvDllVbxeRtUgEwFch+aqeDUxXa1TFcRxngOloa97JY7m/RTT5zfdAP7TFcRynZtrbG3W2TT4uUO84TkvR0cQOufSl0yfttFNQwXsLwljTKTfdZOZfsOBNwxq2efDgZY10tdOf85Af/EU4aWXMwUcbKftvHvLdxx1n2jc/6STD2tzzkG8+8sjAtt3pp9dUZhbWoF52+/tnqmzt85DjeenF35n2lVfdoeaLHbv22tFO7c6nnmqeeciO4zjNxuCO5nVrzdtyx3Ecg/ZWHtRzHMdpJtqa2CGXHkNWDWe7W4s9ki3YQgYPXiawbT9yZGBbZVk7hnzG738Z2IYOXcNMa6116eqy9CWKLAwJ02YtIFi4MFwwM2jQcKNN88z87e1x8y+LxAqz55aG12Cl7eiw29TV9X5gs2LzxRaGhHVlzc609VDCuezd3WE7ISuuar+v9iKQcH5zkX716JSJgW393Q/LyG+1KbyvnZ3vmGkHDQq1Z+xrip+H3EtfrTmm+5UNNoh2ajc8/HBzxZBFZDSgqjpdRNYFxgFPqeqtpbfOcRynIG0ZDzzNQJ640PHA1kCHiEwj2czyHuBoEVlfVU8uv4mO4zjxNHMMOa/lOwEbA18g2U14R1U9EdgK2Dkrk4iMF5EZIjJj0iRbS8FxHKcMOtrbo49GIy9k0ZkukZ4nIs+q6lwAVX1PRKLEhawYsuM4Tlk086BenkNeICJDVXUe8IHKtIgMpxe1t0qswRNbrc0ePLEG8G6aOTOwvfQve2HJYdvsG9gm3XOnmdYeQLIU4LIu3RrUs/Lb31FWuU/+4eLAtuZmu2TUH0e2gluRtFanD9Nm36u4+5o9qGfdw1hbkfc663nCWjCUNahXa12hfeRuh0TWk9Uuq66oj3RmXUU+F2XS3qoxZOALqjofPtC06GEQ8M3SWuU4jtNHWvYJuccZG/bXgNdKaZHjOE4NtOwsC8dxnGajEQfrYindIVs7G6iGPymyJopbCz6sePHKq9gC98OWmBDYDt5yazPt+dOmGdYw1pYtAhNeVxEhImvHinW2+FZgq1UIPGtnjCJpY68r617Z5cbGOrN28bAW4RR5r8K02dcfv2OIVZf1GcgW8bH6YG0/y4tcq0WRtP2NPyE7juM0CNlfjI2PO2THcVqKlh3UcxzHaTaaOWRRurgQ2WrqUbz77rOBzZpbPGyJJcz8Z94aSm5kbai6sCucM336LbfkNbFu2KLl/p1p0Z8bB5RBGQL7tZIlxNTP7arZmx621VbRPuecqVMbynvnaVmMAZ5U1bkisjhwNLAB8ARwiqq+1Q9tdBzHiaaZ99TLe6SYDPRoPZ4LDAdOS22Xltgux3GcPtEmEn3kISLjRORpEXlGRIL91ERkdxF5ND3uE5H1Ks49LyKPichMEZkR1fa88/rf39GjVPUwVf2Tqp4AfLyXi3BxIcdxBoR6OWRJYjUXkChergvsmkoQV/IPYFNV/SxwEqmGTwWbq+pIVR0V0/a8AOXjIrKPql4K/FVERqnqDBFZCwgngqZUigtRYwzZcRynCHWc9jYaeEZVn0vLnQLsQBKyBUBV76tI/wAwopYK8xzyt4BzReRYkqXS94vIi8CL6bnSsXb3sMSBshZ7WAN4WTtcf2fzzQPb6bvuGtiOvOYaM3+t2ANVpVRVEv2363Tzf883Yvtra1MxIaryKDLLQkTGA+MrTJPSB0qAVUh8XQ+zSTThs9gPuK3ifwXukERhbGJFuZnkaVm8BewtIkuRhCg6gNmq+u+8gh3HcQaCIg656td8NdFLR0VkcxKHvEmFeWNVnSMiKwDTROQpVf1jb+2JmlOlqm8Df41J6ziOM5DUcceQ2cCqFf+PAOZUJxKRzwIXA1ur6us9dlWdk/59RURuIAmB9OqQm2fipuM4TgQiEn3kMB1YU0TWkGQX5l2AReKdIvIx4LfAnqr6twr7EmlkARFZAtgSeDyvwgFadVAk1hgnhG0LA8GR224b2KxYMcDEu+8ObO+9
NzuwWbtDg707bxGyBb6r0/XfBP77z7C3TfzfI44JbHYMPKtNsX0gS2A+bmHF2XvsYeY/7IrLjfzxu14XiYvaYu7WexhfZhlx2SIbF2SUEF1umXHleq3UU9VOETkImAq0A5NVdZaI7J+enwD8GFgOuDB18J3pjIoVgRtSWwdwtarenlenLwNzHKelqKe4kKreCtxaZZtQ8fpbGBMc0pkZ61Xb83CH7DhOS9HMWhbukB3HaSnqOKjX75TukO34mSUuHgrZA3R1WbtIxW9maYkDWXOLwY4XL754OM97/vxXzPxf2zCMTV//0ENmWovu7vBa29sXi0qXpLVF/qspEoMec/gPosuw22+3yRbZD59ssjc5jRNiOujSCzPqj73X9mYARQTqu7vDttqfgfhNRtvbF49Oa8+lvzqwZfUrS4w+9v1PSwgs1r2uFy2thywinwC+QjL9oxP4O3CNCws5jtOINHPIotdnexE5BJgALAZ8DlicxDHfLyKbld04x3GcotRx2lu/kxds+TYwTlV/CmwBrKuqxwDjgLOzMrm4kOM4A0U91d76m5gYcgfQBQwBlgJQ1X9KVtCXRZcjqnY34qJ9x3FalEZ88o0lzyFfDEwXkQeAL5BoISMiHwHeiKnAngBu7Q5sP6y3tYVpiy1ACMkSB7IWfFgDeEOGrGDm/+TKKwe2U77xjcD2w2unmPljdwepdReRIvfKuv9pKUa58e2y2jD5O98NbPtOzBqUs8oM+9CgQUtl5A8H2i7b/4DAtteFvzDz24N6dh9OtGWqKbKbeRydne+a9u9fFS6CefvtpwLb0KGrR9dltfWVOfbirJVGhIuzyqRlZ1mo6rki8gdgHeAsVX0qtb9K4qAdx3EailZ+QkZVZwGz+qEtjuM4NdPSDtlxHKeZaMTBulgGZNfpImIj1sR8K781eb23cmP52oYbBjYrVgxw2s03B7ajttsusC275JJm/iOumhzYrIUV3d32Zi1tbZnjrDWQ1T/CTm+1K6tNdh+wFobYi1gse1vbECtldH4rBm4tQMlKm4VVhhXDzo7X11Z/LF1d75v22EUcWf3Sel97aX/N3vSiffeNdmoHTJ7cUN7bn5Adx2kpPGThOI7TILS7Q3Ycx2kM/Am5IEXiulasKWteZxkC7ZY4kDW3GOx4sRVXnrDffmb+rq55gc2KIZcTK84ivnMXaVdsH8iONcb2Ibv95+61T2A79PLLCtQfj1VGW5sVV62t/1514IGmffcLLojKX2u/6t9+mY075AagDGfsOE7z0cyzLFrGITuO40BzPyHnqb0NE5GficgVIrJb1Tl7TSsuLuQ4zsDRyuJCl5LoH18P7CsiXwN2U9X5wEZZmSrFhcieyOo4jlN3GtHRxtLrwhARmamqIyv+PwbYBtgemKaqG0TUUaNDDrPXKi5UK1m7WPx8190C2/Allghs+19yiZl/t43C77gvrRfuk7jPxIl5TWx5+nsn41qYcvDBgW3n884JbAM9DtKfu5n3Qs3e9KoDD4z2ObtfcEFDee+8J+QhItKmae9X1ZNFZDbwR8BebuY4jjOAtGwMGfg98MVKg6r+Cvg+YG825jiOM4A0844hefKb5g6Xqnq7iJxSTpMcx3H6TuO52Xj6LC4kIv9U1Y/lpevunh9UYAmrPHSRLQS+gSEaXkRcqL09FJzJ3sk4tFs76WYtFrC+ca3FHnt/IVxAAnD1Aw8Etl3GjAlsk+++1swfKzCevbtweK+s9qepjXLDH00dHfau07aQjfVRyhIHCvuQtYgmSxzI6oNFdvi246r2D05rh2lLiKdIv2prC9uaHUO3xmFCW2fn22buwYOXCWzWe50lTmS9rx0ddsRTxFgxU5DrDj002ql9/dxzG8p/9/qELCKPZp0CVqx/cxzHcWqjEUMRseQN6q0IbAW8WWUX4L5SWuQ4jlMDzeuO8x3yzcCSqjqz+oSI3FNGgxzHcWqhZZ+QVdVWwUnOhZNuDay4pGUbc/DRMcUB8OQf
Ljbt62zxraj8Rd6vWHHu7PxhXNOaWwx2vHjKgw8Gtu+OHWvm/9wnPxnY9jj/52baQYOGm/ZqrPZnp42/V1kxxFji5yFnjS3E1ZM1NlGM8L7YAvn9N9/X+gxk9wkrhh33uR4I6umQRWQccC7QDlysqqdWnd8dOCr99x3gAFX9a0xei8acSd8HYp2xE++MHacZqde0N0m+IS8AtgbWBXYVkXWrkv0D2FRVPwucRLpCOTJvQMs4ZMdxHEgE6mOPHEYDz6jqc6q6AJgC7FCZQFXvU9WeMbYHgBGxeS3yxIXGVbweLiKXiMijInK1iGTOsnBxIcdxBooiT8iVvio9xlcUtQrwYsX/s1NbFvsBt/UxL5A/qHcKcHv6+kzgJeDLwFeBicCOViYXF3IcZ6AoEkOu8lVBUVaWjDo3J3HImxTNW0kRPeRRFUJDZ4vIN+OyhW0osjtxV9d7gW3NzXYxyrRXchcZlLEGWoosDLEHP8LryhIH2nne84HNGsC78M47zfwHfPGLgW3qJlsGtikPhgtQEuJ2kgZ7AMq6V+3ti5v57QUX1g82exGPtbDDWoSSJZgTuzAka7GDdf1Zg3L2zulh/UV2J6l1sNn6XGYtAuroWCrMHbmIKiHsV7W3P5s6zrGYDaxa8f8IYE5Qn8hngYuBrVX19SJ5q8nrASuIyOEk1zhMRET/u8TH48+O4zQcdZxlMR1YU0TWAP4F7AJU68J/DPgtsKeq/q1IXos8h/xLoOfr8VfA8sCrIrISMDOvcMdxnP6mXnrIqtopIgcBU0mmrk1W1Vkisn96fgLwY2A54ML0i6BTVUdl5c2rM28e8gkZ9pdF5O4C1+Y4jtMv1HMesqreCtxaZZtQ8fpbgDnn1sqbRy176p1AsqNIDlZcNT6uW2RhQq1YMcD+rN8SB7IWe1ixYoCL7rorsJ29xx6B7ajtvmzmP/SCsF99dLUdzbQWWfFii1oXEcTmz4rLtrfHdf0isc4/n3qSad/46OOiy+g/LMGfMFacmdtYhFPk/S+Tll2p5+JCjuM0Gy3rkHFxIcdxmozmdccuLuQ4TovRzE/IfRaoL0BQwazbwnnYn956fGBzYOHCtwLbnsbcYoAxa64Z2L535ZVm2m9uvHFgG2XEq/e54Edm/iWX/JRpd5waqdmb3nvccdFObdOTTmoo713LoJ7TpFjO2HFahWZ+QnaH7DhOS9HMDjlPXGiUiNwtIleKyKoiMk1E3hKR6SKyfi/5XFzIcZyBQQocDUbeE/KFwPHA0iSzKr6nql8SkbHpuf+1Mrm4kOM4A0bt+6QOGL0O6onII6q6fvp6kV2mK8/lUJNDtkRYiuw6XSv9ubODJc5iL4Cwb6m14OPlN6tnLMKv/vxnM/9hW20V2JZeYgkz7U9++9uwVYXulXUN8QJZljhQlkCVWXv0jiNZ3beIolhYV63iQtk7TPedWgW6soSc7LZm3r+avemfTj4h2udscszxDeW983rA+yKyJTAcUBHZUVVvFJFNAfvuO47jDCDNHEPOc8j7Az8n0UDcCjhARC4jUS/6drlNcxzH6QNNHLLo9XePqv5VVbdS1a1V9SlVPVRVl1bVTwM+EdVxnIZDJP5oNPq8MKQ6ptwLPqjXT8x54cbAdv2Pbwhsz778spn/nKlTA9vlBxxg1/XGG4Ht6GuvzWmh08NDl50V2Dbc+/ABaEnDUbObvO/0n0b7nM8feWxDuWUXF3Icp6Vo5Riyiws5jtNctLeuQ3ZxIcdxmopmfkIeEHEha17mPT8+3sy86QmhuLeVH+z5kkXma95/xsmBbczhPzDqKbLiPL5zWJtMWgL5WRuPWvNw33nnaTPtGXv9MLB9fMUwCrXXRReZ+S/ad9/Att+Enwe2wYOXN/Pbc5bD9+ovE043848af2hgKyJ6X2QecK1zhq389jzqrPrDz2iROdexZG3oWkSk3+qbloPs5f7V7E0fOO9n0U5to0N+2FDeu2W0LMpaGNKKWM7YsbGcqWOT9aDQ
3zTzE3Kfl/uIyG31bIjjOE5daFUtCxHZIOsUMLKXfOOB8QATJ05k/HjXOnYcp3+QJl4YkheymA7ci/1dsnRWJhcXchxnwGhih5wnLvQ48BVV/btx7kVVXTWijpocsh2XsgY5ml9cyBpUsQZUsgY1axWcOXXnnQPb8AxxoQMmTw5sX9tww8D2yZVXNvOfdvPvDWu8uJD9vhQZaIsTF6rPvQ6vwR78ak1xIet97aX9NXvT6RNPj/Y5n/vOkQ3lvfN6wE/IjjMfXN+mOI7j1IEmfkLO07L4DSAiMlZElqw6bc+RcRzHGUCaWcsib8eQQ4DfkTwNPy4iO1ScPqXMhjmO4/SJJvbIeSGLbwMbquo7IrI68BsRWV1VzyU61mPFz8JYVVacyo5r2XE9C7vcrLhkWK4lGp8d6wtvSZEJ/Na1WjFkq01J2sWj6smK9VniQAsWvGamteLF1z/0UGD77tixZv5dxmwU2K78812BLWuxhLWwoqMjjO1nXauV377XdlzVGkfIiota8WJ7sUl8vy6yWCOWIjFk+7Ni57c+F2W0/4PaWnjpdLuqvgOgqs+LyGYkTnk1GnIWn+M4H3oa8Mk3lryh2pdFZGTPP6lz3g5YHvhMie1yHMfpE/WMWIjIOBF5WkSeEZGjjfNri8j9IjJfRI6oOve8iDwmIjNFZEZM2/OekPcCFvltpclvrb1EZGJMBY7jOP1KnZ6QJYlLXQB8CZgNTBeRm1T1iYpkbwCHADtmFLO5qtpxP4NeHbKqzu7lnL1TZoAVV40XgenoCMV1sueGxs4Ptt8wK78l7lMW1rVaxMaKsygyjzpLHMiaX2zFiy+8804z/7Mzrgps5+xl7wp2xNVXBzYRKzYfL2LT3h4357cesc5rDw2FkHb9xbmBraz57bF0dCzZy1ziRbHi5bX2y3pRxynao4FnVPU5ABGZAuwAfOCQVfUV4BUR2bYeFRZuuoisUI+K681Ad2anHCxn7JRDrDNueOoXs1gFeLHi/9mpLRYF7hCRh1I5iVzytCyWrTYBfxGR9UlW+YX7+DiO4wwgRbQsKnV3Uial0g9QZOmozcaqOid9iJ0mIk+p6h97y5D3u+014IUq2yrAw2nDPm5lcnEhx3EGjAIx5CrdnWpmA5XyECOAOQXKnpP+fUVEbiAJgdTkkH8AbAEcqaqPAYjIP1R1jZyGuLiQ4zgDQh1jyNOBNUVkDeBfwC7AblFtEFkCaFPVt9PXWwIn5uXLG9Q7Iw1kny0iLwLHUwcHW2QHBnsXg3hxoSKCLbGLWLLj1WFPKCICY11rR0f1ivXshSHxg6VZb2H4ZJEVV7TEgazFHtbgHcAnRu0e2P79Vhgvnj//VTO/dV+sQaWsAWCrD1p9KHthSNivsoTRdznvTKNca8FRVr+yBsZrE9O6zhho/OpZPzPTxg5sZ/XLMtrfK3WaZaGqnSJyEDAVaAcmq+osEdk/PT9BRFYCZgDDgG4ROQxYl2Rq8A1pn+gArlbV2/PqzPVW6UyLr4vIl4FpQP9NO3AcxylKHReGqOqtwK1VtgkVr18mCWVUMxdYr2h9uQ5ZRNYmiRvfDfwB+ERqHxfj8R3HcfqTJl6oV0xcCNhSVR9PT7u4kOM4DYe0SfTRaPSDuJBFkWkpsaLh9bi5VhmhbfJ3vmvm3m9SrYsXY6+h1lGL2u5/VhmWOFDWYg8rXnz6LbcEtoO+9CUz/5k3h7FpK4ac1X7VuHuQld+OF8cvOLI3hqi/6HwW2/7se4a11s9Q/7W/V5r4EdnFhRzHaS0a5HuhL7i4kOM4LYWIRB+NRp5D3gt4udKgqp2quhfwhdJa5TiO01ekwNFg9IO4kOM4Tv8hbc0bsyiyaqJu2Dv+2pPirbT2oF49FgSGZVh17TvxQjt3jTshx19D/M4SxeoJHxn+MuF0M+Xo/Y8MbNbuHlniQNaCD2sA7/xp08z8qxk/Nzf7/OcD22V/utfMH3uvsxaW2MpyWQOIVh+2+kr/PbINHbp6YOvq
mldjqfH3qlQa8Mk3lgFxyI7jOGXRiLHhWPLmIT8sIseKyCeKFCoi40VkhojMmDQpS7fDcRynBFo1hgwsAywN3C0iLwPXANf2qBhl4eJCjuMMFM38hJznkN9U1SOAI0Tk/wG7Ag+LyJPANRW6oZlYsTJL2EU1K/5mpbVEgMKdfQHazAB/1q7TcW3tzgiV2R0hfoDBvlZrd98wHcSLC2XnD3fhGDU+FKFJ2hXeK6tce2cPWxzIWuxhxYoBXjD6wPYjRwa2BQtsyW6rftXw42C9J1lkxZtjd53O6lcW1n3NdkSWaFR8v7r/lBMC28bHHhOd325XmeJC5RVdNtHeQlX/T1W/S6JrcRrwv6W1ynEcp4+08tLpv1UbNHk0uj09HMdxGosmDln0+oSsqruk21yPFZFFfuOJyLhym+Y4jlOcZl6pJ7bISXpS5GDgIOBJYCRwqKr+Lj33sKpuEFFHTYN6tkB6WGSx+b5F6rfmTJcz8dwSQ7eFvOPnEfcvVrtqa9M3N97YtL/57ruB7aaZMwPbGbvZGzy8tyC818f95jfFGlcD9iYN8QL1ZZAVL4//bNWlX9Z8sU/fe2m0z/nUpvsM9IdmEfLu9HhKUXtzHMcpiQZ88o3F1d4cx2kpmtgfu9qb4zitRTPPsnC1N8dxWguR+KPBGBC1tyIDZfGDerUPiJy9xx6B7aBLQyGhQYOWii6zSP32oEo4qJe1E3RZA5uxFBFXiu0DWeJA1oIPawAvS9zIyn/+3nsHtgMvvcTMn93fQuxFINau11klWEJG8fVPOfjgwLb9ad8PbIsttrJde2S/yuqXtba/MI3nZ6NxcSHHcVqKRpzOFkueuNBwETlVRJ4SkdfT48nUtnQv+VxcyHGcgaGFxYV+DdwFbKaqLwOIyErAN4HrAHMHShcXchxnoGjEwbpY8haGPK2qnyp6roqggvgFENDV9V5YoBF/zMpvCeZkYQuuzA9sWTE166fSuXvtE9gOu+IKM//ChXMD26BBwwJbV9f7Zv729sVMezVZIjj2Dt9Z4jphWut9zWqTldZ+ZMkSggrbdcrOewW2o662f6ENHrxsYLPu/7c229rMP+muG40ylzPTqobiQrYQkx1XtfpVrJBUETo73zHtlhCT9b50dYWfFSjc/pq96bMzrop+CPzEqN0bynvnzbJ4QUR+ICIr9hhEZEUROQp4sdymOY7jFKeJJ1nkOuSdgeWAe0XkTRF5A7gHWBb4RsltcxzHKU4Te+S8aW9visj1wG9UdbqIfBoYBzypqrbQrOM4zkDSeH42mrwY8vHA1iSOexowGrgX2AKYqqonR9QRFc+5+chw00yAbX9+WnSR5cxtDOu6bP8DzJR7T5gQ5jbn5trtLBJb/7ATO4/ZmlsM8J1fnhfYrHh9Z+fbZv57jj81sG1xcszHIaFIv4hlwYLXTPvcuX8NbH844cbAtvN5Z5v5+3l+e83u9B8zr4mOIa8xcteGct95IYudgI1JVuUdCHxFVU8EtiIJZzQMpU40dxynaajn0mkRGSciT4vIMyJytHF+bRG5X0Tmi8gRRfJa5DnkTlXtUtV5wLOqOhdAVd+j9r3oHcdx6k+d5iFL8pR3AUmUYF1gVxFZtyrZG8AhwBl9yBuQ55AXiMjQ9PWGFZUNxx2y4ziNSP0G9UYDz6jqc6q6AJgC7FCZQFVfUdXpQPXcxty8FnkO+Qvp0zG6aNBuEMniEMdxnIaiiD+uXFWcHuMrilqFRaf3zk5tMfQpb6+DenXCWBgSTpTPWsDR1TUvLNBoc3kLQ8KBtiIT+IsMiFgDSB0doZBRrQtDiuzs0J8LQ+yFKfYPMatd7e1DjXS24M3em4RihZfcG24Tad1/sHe4XmfECDPtz276bWCzPwPxfaWchSHhLiwAHR1LGNbGXRjywhO/jnZqq637jcz6ROTrwFaq+q30/z2B0aoaqDWJyE+Ad1T1jKJ5K3FxIcdxWoo6Ti+eDaxa8f8IYE6ZefPEhVYSkYtE5AIR
WU5EfiIij4nIr0XE1urDxYUcxxlA2iT+6J3pwJoisoaIDAZ2AW6KbEWf8uY9IV8G3AIsAdwNXAVsSxKcnkBGkNrFhRzHGSikTitDVLVTRA4CpgLtwGRVnSUi+6fnJ6RiazOAYUC3iBwGrKuqc628uW3PWRjyiKqun77+p6p+rOLcTFUdGXNd4YXGC9RbscYi4kK17hBtxfqyY8i11WXFy624qNUmKBYvL4MiYwN2H7Bi2Fkx5DA2XGQRzfz5rwa2/zvxnMB23i23mPmtHa4P+OIXzbRDBoX34KzbwoelInPpy1isYQl5AbS3Lx6VP3u8IaSX9tfsTV/822+iHwJXXWunhloYkveuVnqYy3s55ziO0xg0oEZFLHkO+XcisqSqvqOqx/YYReSTwN/KbZrjOE5xmtgf54oL/VhERouIpuJC65KICz2lqjv1TxMdx3HiaWWB+mpxoTEk8ps1igtZddo30Y6X2nFFO4ZY25tjxcXKElux5hdb83iz4qr3nRa+HRsffVztDYukmGBObB+w+6e9SWhtc86tMYCjttvOzD93Xhjvv+iuu8y08+aF0uE/2GHfwPaLO8J50GnLQkuN4xUW1mYMED/nOXuT05Be4uU1e9N//eOG6BjyKmt8paG8d55n2QkYCQwBXgZGpKOHpwMPAvHyViXjqmiO4wBNLb+Z55A7Nfnamycii4gLiYhrWTiO03C07K7TuLiQ4zhOv5H3hPwFVZ0PLi7kOE6T0MRPyP0gLhRWYC8gaNRdp+svLjTlYFtfZKezTwlsZYgL1WPXaeseWOIy1x56RGAD2OW8M6PKzGqr1YcsEZys9tsLS8LBK3t3bDh86y8HtlNuuNhMO3ToqoHt/fdDWYOsfjVo0NKBzbqvXzvzBDO/tcO2RVm7Tlv00ldr9qYvzb4p2qmtPGL7hvLehYdqRcTe69xxHKcRqJNA/UCQJy50qogsn74eJSLPAQ+KyAsismkv+VxcyHGcAUFEoo9GIy+GvK2q9uwFdTqwc7pAZC3gamCUlWlRcaHyBZcdx3E+oPH8bDR5C0OeAv4nVT16QFU3qjj3mKp+Jq+C7u6FhrhQfAy5szOcgG/Fr7JjyJbdvmYrrmjHkLO+xywh7vgY9sKFbwW2QYOGB7YsERg71heSFRd95PLzA9vIPb9rprXizbZAvS1MYy9CsGLY9mIDKzZsxZCtBSRZ5Vpxzax4vSUmf8hW9iKSM37/q8C22GIfDWxj117bzH/zI3ca+UP126wnPvszHsbms2LIVgzbiuF3d9v3ynpfOzpC0ayE2h9b//3yzdEPgSuutF1Due+8J+QLgFtF5FTgdhE5B/gtMBaYWW7THMdx+kADhiJiydOy+IWIPAYcAKyVpl8LuBH4aemtcxzHKUgjxoZjiRFlmAeckcaOP00iLjRbrbiD4zjOgNO8DrmouNBo4F4KiQvFzkPO2uTUmodst9mKF9c+DzmMdRYRqC8iRLRw4dzANmjQsMBW+yan8eI65W1yGhdDzloQasWGrbhkVgzayl+krXYfsB3BggWvBbZt1gsnKd351FNm/s+vvnpgG/2pTwW2c6ZONfPHUmQecuxnJSG8L2XOQ37llanRMeQVVtiqoby3iwt9CMlabOGEZDsZp1Fp5ZCFiws5jtNktK5DXiAiQ1V1Hi4u5DhOE1Bkb8JGw8WFHMdpMVr0CbnHGRv214BwpCIa64Zl3cTYXSSaf0FgI8a+sgZQa29qGdca36/Kudd2mdbCCmuxhzV4B3Df888Htm9vGg4KHj5unJn/rNuzdiKpN/15r3tpRQm7qfQX5exF5DiOM2A03oNNLHniQkuKyIkiMktE3hKRV0XkARHZOyefiws5jjNANK/cW94T8lXADcBWwDeAJYApwLEispaq/sjK5OJCjuMMFM0csshbGPJXVV2v4v/pqvo5Sa74CVW11VAWpSaH/MhV5wS2kbsdEtia+U3oociCmQ87sQtbGoFrjA0Jdjnv3MD2vXFbm/nf
fT9cCPTLe+8NbGftvruZ/3tXhuJG1oKlIhsXlEjNj61vvHF/tM9Zdtn/bajH5Lw7/a6IbAIgIl8G3oAPZlw01IU4juNAa+shHwD8MtU/fhzYF0BEPkKiBOc4jtNgNOYvpRh6bbmq/hU4GPiSqm4CdIjI4cDnVPW8/mig4zhOMeo3qCci40TkaRF5RkSONs6LiJyXnn9URDaoOPe8iDwmIjNFZEZUywuKC40B7qGQuFD9JwgPdKyrs/Nd024JpF914IGBbfcL7B8XlhBOM686KherW9X2E9QSAQIYPHj5Gst9wyhzGSOl3X5rfvGI5cKtLQ+/6ioz/9UHHRTY/vNu2IcPmHyJmb/ZYsj/+c+MaJ+z9NKjMuuT5MP3N+BLwGxgOrCrqj5RkWYbkofWbUj847mqOiY99zwwKl23EUXLiAs16oCO09zU6oyd/qeODzGjgWdU9bmkXJkC7AA8UZFmB+ByTZ5sHxCRpUVkZVV9qS8V5nmxTlXtSrUsFhEXwrUsHMdpSOoWslgFeLHi/9mpLTaNAneIyEMiMj6m5S4u5DhOixEf9UgdZaWznJSuo8gqqDoc0luajVV1joisAEwTkadU9Y+9tcfFhRzHaSmKTGdbdBFbwGxg1Yr/RwBzYtOoas/fV0TkBpIQSK8OuddBvTpRQgX1H9DJ4vRddw1s37/qcjNtrYs4rN05iuw44oRMMRZlAOzyi18EttdeCwV/ll9+bN3bBLUvbLH6yjUHH2am3e38cDfxCfvtF9hGfX5dM/8aO44KbMstF4ob1YmaP8hz5z4e7XOGDfuf3gb1OkgG9cYC/yIZ1NtNVWdVpNkWOIj/Duqdp6qjRWQJoE1V305fTwNOVNVelZ780+44TotRn4czVe0UkYOAqUA7MFlVZ4nI/un5CcCtJM74GZL9R/dJs68I3JA+rXcAV+c5456EmYjIMOCHJI/ht6nq1RXnLlTV72bk+yAuM3HiRMaPj4pnO47j1Ew9Z1yp6q0kTrfSNqHitQLB3NZ0ZsZ61fY88p6QLwX+DlwP7CsiXyN5ZJ8PbJSVqSou4+JCjuP0I423JDqWvIUhM1V1ZMX/x5A8nm8PTFPVDbLyVlCTQ7YXgYRFlreAIqzr7bft3YGXWmqdmmqKFxfKuqXN0xGvO/TQwLbtz74X2IYOXd3MHxuDnTfveTO/Va4Vb7ZizfWgjPGCC/fZx7S3tYX3Zf9LwkUg1gIWgP874czANvbk0pYg1NyJ33nn6Wifs+SSn2qoD01eDxgiIm09MyxU9WQRmU0yUhjuDe44jjPgNO8isbyW/x74YqVBVX8FfB9YUFajHMdx+k6LCtSr6g+qbSJyuaruBaxZWqscx3H6SCPKasaSF0O+qdoEbA7cBaCq2+dVoNoZVGDFSsGOl3Z2zjNS2osE29oGG7YhVqvM/FZcsrs73Oc1O9YXdoQic5MXLnwrsA0aNNxM29UV3peOjqWi6+ruDn/gWPevqysURwc7XmuV2d6+eEb91v65sRvaQnd3GIPt6Bga5jYEm5L8YR9sb1/MTNvVZfWB8PqzRvetumqNIbe3h9eaxUOTzw5sn90zjDcPHrysmf+bG28c2P756qtm2t/NuCGwWfF6S4grpWZv+u67/4iOIS+xxBoN5b3zesCqwCzgYpJPhgCjgDDKXyO1LqqwnEmrYjnjIliO07GxnLFjYznjgaCZn5DzYsgbAg8BxwBvqeo9wHuqeq+qhnvIOI7jDDitG0PuBs4WkevSv//Oy+M4jjOwNO8siyjnqqqzga+n67bnltskx3GcvtPMIYsmFRdqVfpPNKnZ+dNJJwS2TY47fgBa0vi8/noYXZx51h2B7fJ77jHz/+rPfw5sx+24Y2Dr6rYH20+5qXpuQK/U3OHff39OtM9ZbLGPNtQHzMMPjuO0GA3lYwvRa7BFRMZVvB4uIpekG/ldLSIr9pJvvIjMEJEZkyZlSY06juOUQYsO6gGnAD2ScWcCLwFfBr4KTAR2tDK5uJDj
OANFy8aQReThHgEhQ2hokf97oUZxoXACvbWAo6x5yP25E7S1WMJa2DLQu25nUavoerG6ahXniY3X10PIKSzDWthSpF+VcV83X2st077JuqFw/Uk33hjYdtvIFoDcauTIwPbNCRPChAk1e9P581+J9jlDhqzQUN47rwevICKHk9ykYSIi+l8P3rxzSxzHaWEayscWIs8h/xLoWY/7K2B54FURWQmYWWK7HMdx+kR5Urzlk7cwJJhbVCEutFdprXIcx+kzLfqEbIgLAXxRRJaGOHEhx3Gc/qSZB/X6Ii70OQqJC1kDGtbOGPagnCWEkzWoZWGVm53fams40PbKnGlm7hVX2caoP140yVJWswb1bKW0bGW1arIU0KyfelnKfOWovVnDEvZ7Zau9he3PVnsL81tqb1niQm1t1kfHdgS22ltoK6IimKVMVwtZ4kCn7vHDwGYN4F39wANm/jffDO2HjxtnpISzbs/dBzSC5nXILi7kOE6L0aLzkF1cyHGcZqOVQxaAiws5jtNMNO+M3EJPu6p6C3BLkTxf2WDDwHb9jAeLFGFQ2zdg1qR6O7Ycpl1pxLY11Z9N7HWVc/122qy6LHuRdtV2rXa74tsU+xRV3tOW9R7U1tZaydrh2xIH+tX++wc2K1YMsMwyYby5M0OIqB60/BOy4zhO89C8DjlPXGiUiNwtIleKyKoiMk1E3hKR6SKyfn810nEcJ57mHdTL++16IfBzkjDFfcBEVR0OHJ2eM6lUe3v+tdfq1ljHcZw8RCT6aDTyxIUeUdX109f/VNWPWed6Y5cxY4IK2owbceZ14VxHgJVX3SGvipahP8V5nIGm9TYjyJpbbMWLz7sjFMhPqfkmdHcvjBYXamsb1FA3PS+G/L6IbAkMB1REdlTVG0VkU8Cece84jjOANOKTbyx5j18HAN8H9gW2AjYXkTdJwhWHltw2x3GcPlC/GLKIjBORp0XkGRE52jgvInJeev5REdkgNq9F3sKQmSSOuIdDRWRZVd0zpnDHcZz+pz5PyJJoCVwAfAmYDUwXkZtU9YmKZFsDa6bHGOAiYExk3oC+igvdBC4u5DhO41HHkMVo4BlVfS4tdwqwA1DpVHcALk914h8QkaVFZGVg9Yi8IaqaeQCPAFcCmwGbpn9fSl9v2lvejPLG1zttGWU2U/3N1NaBrr+Z2jrQ9TdCW/vjAMYDMyqO8RXndgIurvh/T+D8qvw3A5tU/H8nMComr3X0t7jQ+BLSllFmM9VfJO2Hvf4iaT/s9RdJW1b9paOqk1R1VMVRuStzzJ5eWWmK7Af2AS4u5DiOYzObRIK4hxHAnMg0gyPyBkRNclXV2ar6deA2khCG4zhOqzMdWFNE1hCRwcAuQPW42k3AXulsi41IIgkvReYNKF1cqIpJ+UkKpy2jzGaqv0jaD3v9RdJ+2Osvkras+gcUVe0UkYOAqUA7MFlVZ4nI/un5CcCtwDbAM8A8YJ/e8ubV2etKPcdxHKf/8HW5juM4DYI7ZMdxnAbBHbLjOE6DUKpDFpG1ReSodK33uenrdTLSjRWRJavstnzUomkuz7CPEZFh6evFReQEEfm9iJwmIsMr0g0Wkb1EZIv0/91E5HwROVBE4reMdnIRkRUKpF2uzLbUk9jrasVrStM2zXU1OqU5ZBE5CphCMkH6LyTTQAS4plJoQ0QOAX4HHAw8LiKVepunVJV5U9Xxe+CrPf9XNWEyyagnwLkkinWnpbZLK9JdCmxLotNxBfB14EHgc8DFfb4BNdKfHwgRGS4ip4rIUyLyeno8mdqWrkg3TER+JiJXiMhuVWVcWPX/slXHcsBfRGQZEVm2Ku2pIrJ8+nqUiDwHPCgiL6TKgpVpozZNiL2msq6rjGtqpveqyDU5FZS4JPFvwCDDPhj4e8X/jwFLpq9XJ1m+eGj6/yNVeR8mcik38GRlvqpzMyteP5r+7QD+DbSn/0vPuaq8w4FTgaeA19PjydS2dEW6YcDPgCuA3arKuLDq/2Wr
juWA54FlgGWr0p4KLJ++HgU8RzLl5oXKe5Ceuzu9X6sC04C3SL4Y168qcypwFLBShW2l1DatwnZ9Wv+OJHMqrweGZNzjbuAfVcfC9O9zVWkfq3h9N/C59PVawIyqtH8hEXTZFXgR2Cm1jwXuL3pNZV1XGdfUTO9VkWvyo+L+llZw4rBWM+yrAU9X/P9E1fklgduBs6hwnOm5NuB7JM5lZGp7LqP+64B90teXAqMqOs70inSPk3xJLAO8TeoAgcWocOoV6VvuA1H5fhjXW/leVb8fxwB/JvkCqb6mI9L38TMVtn/00lc60tcPZF1v+v8jFa//2cu5qGsq67rKuKZmeq+KXJMfFfemtIJhHMmT220kk8EnpW/6M8C4inR3kTrXClsHcDnQlVH2CBKHe371m12RZjhwGfAsSQhiIcnT5L3AehXpvpfaXwAOIREH+SXJk/vxRrkt94EA7gB+AKxYYVuR5EvmDxW2J4G2qrzfBGYBL/TyPp0FLEX2l+fBaRu+CPwEOAf4AnACcEVV2vuBLUlCSy8AO6b2TVn0Cynqmsq6rjKuqZneqyLX5EfF/S218OSJdiPgayTqRxuRhgSqOsJKGfk3zil/W+CUnDRLAeuRCCWtmJHmo8BH09dLp20dnZG25T4QJL8OTiNx9m8Cb6TtP42KkAnJ/opbGG0aR0UYyjj/ZeAB4OVe0mwGXEuiMPgYyQqo8VSFvdL3cirJF/3aJOMD/0nv6+eLXlOZ11XjNb2ZXtPGVWmrr+vN9Lp+Xqf3avuI92pz47q+U3ldwMjYa/Kj4t4OdAOa7aj6QLxR9UFfpiLdQDivjoo0UY6rIv3awBak8fzK9hrpxhrpts4ocyxJGGpx4H+sMnPKtdKuE5OWRM+2J6TzaZLdb7bJuKeVadcFDo9M+xngWCttwfrHxKY18l4Rme7yyHSLA9cV+EzElhvVzg/z4Uun64iI7KOql9aaTkQWBz6hqo/HlllL/ZLMdDmQ5ItlJMmg6u/Scw+r6gbp64OBg/LSFSmzj2m/S/KF2FtbjyeJoXeQjDmMJglXbQFMVdWTK8qsTjsGuCcyrVlujfX3ltbcNIIk9Iemm0YY6YTkyXaRdEXKrLH+zDKdCgb6G6GVDjLi2X1NV1ba6nREznSJTdcIadN07cBQYC4wLLUvTtXsmTLSllh/1EwjCmwuEVtmWfX78d/DtY0LIiKPZp0iiSUXSldW2iJlksT13wFQ1edFZDPgNyKyGosKbcema4S0naraBcwTkWdVdW6a5z0Rqd6Xvoy0ZdU/imSD4WOAI1V1poi8p+GGERtGpitSZln1OynukIuzIsnGr29W2QW4rw/pykpbpMyXRWSkJpvaoqrviMh2JItrPtOHdI2QdoGIDFXVeSTOIbn4ZJVmtZMrI20p9WvkphGx6cpKW6RMp4KBfkRvtgO4hIo9tKrOXV00XVlpC5YZNdMlNl0jpCWd822kWZ6K6YVlpS2rfiNN7kyjIunKSlukzA/z4YN6juM4DYKrvTmO4zQI7pAdx3EaBHfIjuM4DYI7ZMdxnAbBHbLjOE6D8P8B0FVodkK6CiQAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 432x288 with 2 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "from scipy.spatial.distance import squareform\n",
    "\n",
    "# Index of the sample to visualise; the ground-truth cell below reuses it\n",
    "# so the two heatmaps show the same graph.\n",
    "idx = 10\n",
    "\n",
    "# squareform expands the selected 1-D (condensed) weight row into a square,\n",
    "# symmetric matrix so the learned graph can be drawn as a heatmap.\n",
    "fig, ax = plt.subplots()\n",
    "sns.heatmap(squareform(w_pred[idx, :].detach().cpu().numpy()), cmap='pink_r', ax=ax)\n",
    "ax.set_title('prediction (sample {})'.format(idx))\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 63,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAWQAAAELCAYAAADuufyvAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAAAqTUlEQVR4nO3dedwd4/3/8dc7iT0klpLW3qLKtwQpWi2CEEvtWqIoaW+7qqX8SotaGl9q6ZdIUhF7tPaUEKm1rapYYg8lqBS1L6klIp/fHzO3jjNzzszcZ5s59+f5eMwj51xzXTPX5JbPPa65rs/IzHDOOdd+fdrdAeeccwEPyM45VxAekJ1zriA8IDvnXEF4QHbOuYLwgOyccwXhAdkVgqS7JP2oyefYVNKsZp7DuXp4QHaFI+mHkv7SgOOYpFUa0SfnWsEDskskqV+7+1CLpL7t7oNzjeYBuZeRtK6khyW9L+lqSb+XdEr3/85LOkbSq8AESQtIOkfSy+F2jqQFwuPE7mKjd6SSLpZ0vqSbw3P9XdJXInWHSZoh6V1J5wEKy78GjAG+KWm2pHcix7tA0mRJ/wGGVg5zRPsk6Z6w+JHwON+P1DtS0muSXpG0b8P/kp3rIQ/IvYik+YHrgYuBJYCJwE6RKoPC8hWBLuA4YENgMLA2sD5wfI5T7gGcBCwOPAucGvZjKeDa8FhLAc8BGwGY2VPAAcDfzKy/mQ2MHG9EeIxFgZpDGma2cfhx7fA4v49c4wBgWWAkcL6kxXNck3NN4wG5d9kQ6Af81sw+MbPrgPsj++cBJ5jZx2b2IbAn8Csze83MXicIrnvlON91Zna/mc0FriAI7ADbAE+a2TVm9glwDvBqhuPdaGZ/NbN5ZvZRjn5EfUJwTZ+Y2WRgNvDVHh7LuYbygNy7fAn4l30+o9RLkc+vVwS6LwEvRr6/GJZlFQ2yHwD9I8f97Lxhf6L9qCZLnTRvhr8gkvrlXFt5QO5dXgGWlaRI2fKRz5Wp/14mGL7otkJYBvAfYOHuHZIG5ezHZ+cN+1OrH9XKP9cHguEI50rLA3Lv8jfgU+AQSf0k7UAwLlzNROB4SV8Ix31/CVwe7nsEWFPSYEkLAifm6MfNYdudw9kch/H5YPpvYLlwzLuW6cDOkhYOHyaOrNj/b+DLOfrlXFt5QO5FzGwOsDNB4HoH+AFwE/BxlSanAA8AjwKPAQ+FZZjZM8CvgD8B/yDlIVtFP94AdgNGAW8CqwJ/jVS5A3gCeFXSGzUOdTYwhyDwXkIwTh11InCJpHckfS9r/5xrF3mC+t5N0t+BMWY2od19ca638zvkXkbSJpIGhUMW+wBrAbe2u1/OuWAKlOtdvgr8gWBmwXPArmb2Snu75JwDH7JwzrnC8CEL55wrCA/IzjlXEKljyJJWB3YgWPtvBAsDJoU5B7KIjYmcsNNOsUqvvftuYuML7rgj42mcc0W0y3rrxcoeeOihxLovmilxRw4rSpnHYRtxvkaqGZAlHUOQIOYq/pvzYDlgoqSrzGxUk/vnnHO5zDd/2nqi4kq7Qx4JrBkmgPmMpLMIJu4nBmRJXQTZwhg7dixdXV0N6KpzzqWbb4EF2t2FHksLyPOIJ5gB+GK4L5GZjQPGAZyw0052wi23fG7/Sddfn9juvB/+MKU7Lq9z9kpOznb4ZZe1uCeut7r2wQdjZb/de++mna+T75APB26X9A/+m2lrBWAV4JBGdsSDcet4MHadbMGFF06vVFA1A7KZ3SppNYIENMsSvNVhFjDNzD5tQf+ccy6Xfn3KO3ksdZaFmc0D7mtBX5xzrm59+5b3dYu+dNo511H6lTggN33p9IGbbRY7wZorrBCrd8jFFze1H2Vw8q67xsp+cc01bejJfx297baJ5WfcfHOLe9J8v9hxx1jZyTfc0PJ+9HJ1zwvefPXVMwe122fMKM88ZOecK5v5+5U3rJW35845l6BvJz/Uc865MulT4oDcivSbDT/B
mSNGxMre//DDxLrVFqE4V2YHbb55rGz07be3oScNV/eY7k7rrps55lz/0EPlGkOWtD7Bm9qnSVoDGA7MMLPJTe+dc87l1EeFirG5pCUXOgHYGugnaSqwAXAXcKykdczs1OZ30TnnsivzGHJaz3cFNgI2Bg4GdjSzXwFbAd+v1khSl6QHJD0wbty4hnXWOefS9OvbN/NWNGlDFnPDJdIfSHrOzN4DMLMPJWVKLkQTxpCdc66aMj/USwvIcyQtbGYfAJ9lmZY0gBrZ3hop6QHeUVdembn9ignjSS+W6D2C40aOjJV1jR/fhp64IumQB3hN0bdTx5CBjc3sY/gsp0W3+YB9mtYr55zroY69Q+4OxgnlbwBvNKVHzjlXh46dZeGcc2VTxId1WRU+IFdb8JHV8E03jZUdMmxYYt3zpk6t61zN4OPFzuXjd8jOOVcQ8oDsnHPF0LEP9ZxzrmzKPGRRyuRCSXOLk8aKAcbeeWes7Iw99kisu/xSS8XKdv+//8vXOedcPeqOpodvtVXmmHPOlCmFit5puSw2AJ4ys/ckLQQcC6wLPAmcZmbvtqCPzjmXWZnfqZc22HIR8EH4+VxgAHB6WDahif1yzrke6SNl3oomLSD3MbO54echZna4mf3FzE4CvlytkScXcs61S5kDctpDvccl7WtmE4BHJA0xswckrQZ8Uq2RJxdyzrVLI6e9SRpOMDrQF7jQzEZV7N8TOCb8Ohs40MweCfe9ALwPfEqQqG1I6vlqPdQLkwidC3yHYKn0usBL4XZY94lTtCQgV1vssWLCg7qjJ05MrDsmIZHPP19/PVZ22qRJOXvnnMuo7mj68+23zxxzTps0qer5JPUFngGGAbOAacAeZvZkpM63CJ6zvS1pa+BEM9sg3PcCwchC5jQTabks3gV+KGlRgiGKfsAsM/t31hM451wrNXAoYn3gWTObCSDpKmAHgkkNAJjZvZH69wHL1XPCTPOQzex9IMvdsHPOtVWeN4ZI6gK6IkXjwiFXgGUJRgO6zSJ4a1I1I4FbIt8NuE2SAWMjx63KF4Y45zpKnjHkiuddsUMlNalyzqEEAfnbkeKNzOxlSUsDUyXNMLN7avWnYwJytcRAVx16aKwsaawY4ABP5BOz/9ChieVJC26K6ORdd00s/8U117S4J65VGjhkMQtYPvJ9OeDlykqS1gIuBLY2sze7y83s5fDP1yRdTzAEUjMgl3fRt3POJZCUeUsxDVhV0sqS5gd2Bz73RF/SCsB1wF5m9kykfJHw2RuSFgG2BB5PO2HH3CE75xw07g7ZzOZKOgSYQjDt7SIze0LSAeH+McAvgSWB0WGA757etgxwfVjWD7jSzG5NO6cHZOdcR8nzUC+NmU0GJleUjYl8/hHwo4R2M4G1856v4wNyUnKgn2+/fVPOddiWW8bKfnvbbU05V6uUZay4Gh8rhu3WWitWdtOjj7ahJ63R0fmQJX0F2IlgcHsu8A9goicWcs4VURGXRGdV895e0mHAGGBB4BvAQgSB+W+SNm1255xzLq8GPtRrubTBlh8Dw83sFGALYA0zOw4YDpxdrZEnF3LOtUsnJxfqrvMpsACwKICZ/VPSfNUaeHIh51y7FPHON6u0gHwhME3SfcDGBLmQkfQF4K0m961pmpUcaJVBg2Jlv9pll1jZL6+9tinnL7sjhg+PlZ11a+pMoYY5+wc/iJX99PLLW3b+ZmnVA7wL9tsvsfzAiy5qyfm7NXKWRaulJRc6V9KfgK8BZ5nZjLD8dYIA7ZxzhdLJd8iY2RPAEy3oi3PO1a2jA7JzzpVJER/WZVXKt063UtJij6SxYoDDLr00VvbjTTaJla3z5eS3Xx00wV9T6Hq9uqPpBfvtlznmHHjRRYWK3n6H7JzrKD5k4ZxzBdHXA7JzzhWD3yF3sKTkQElziyF5vPh3d98dK7vkgAPq75iryzHbbRcrO/2mm9rQk8Y5ceedk8uvu67FPWkvD8jOOVcQZZ5l4QHZOddRynyH
nJbtbTFJv5Z0maQRFftG12jnyYWcc23RycmFJhDkP74W2E/SLsAIM/sY2LBaI08u5JxrlyIG2qxqLgyRNN3MBke+HwdsA2wPTDWzdTOco9cE5NH77hsrW2SBBWJl+4wZEysDOH6HHWJlyy65ZKys1claXH22Hzw4VjZp+vSW96Mk6o6mVxx8cOaYs+f55xcqeqfdIS8gqY+ZzQMws1MlzSJ4lXX/pvfOOedy6tgxZOCPwGbRAjO7BDgSmNOsTjnnXE+V+Y0haek3f1al/FZJpzWnS84513PFC7PZ1TPt7SSCh34NcfKuuyaWl+mtwVmTAyWNFQOccuONsbKkpO2uXHy8uLU6NkG9pGqvGhCwTOO745xz9SniUERWaXfIywBbAW9XlAu4tyk9cs65OpQ3HKcH5JuA/mY2vXKHpLua0SHnnKtHx94hm9nIGvtGVNvXE3nGiseNTO5W1/jxjepOUyXNLYbsL/msNga94he+ECv78YUX5uydc+XWsQG5TMoSjJ1zzVXmgFzex5HOOZegr5R5SyNpuKSnJT0r6diE/XtKejTc7pW0dta2SdKSCw2PfB4gaXx44islVZ1l4cmFnHPt0qiFIZL6AucDWwNrAHtIWqOi2vPAJma2FnAyYQ6fjG1j0oYsTgO6BzF/A7wCfBfYGRgL7JjUyJMLOefapYFDFusDz5rZzPC4VwE7AE92VzCz6Gyz+4DlsrZNkmcMeUgk0dDZkvbJ0bahyj5enCc5UNIDvKQFJACn7rZbrCzpLRK97Q0SrnfJE44ldQFdkaJx4Q0lwLLAS5F9s4ANahxuJHBLD9sC6QF5aUlHEFzjYpJk/00P5+PPzrnCyXOHXPF/87FDJTWpcs6hBAH523nbRqUF5N8Bi4afLwGWAl6XNAiYnnZw55xrtQbmQ54FLB/5vhzwcmUlSWsBFwJbm9mbedpWSpuHfFKV8lcl3Zl2cOeca7UGjiFPA1aVtDLwL2B3oPLNSSsA1wF7mdkzedomKUxyIZcsabFH0lgxwHFXXx0ru/ygg2JleZIbuewOGTYssfy8qVNb3JPerVEB2czmSjoEmAL0BS4ysyckHRDuHwP8ElgSGB2ed66ZDanWNu2cnlzIOddRGrkwxMwmA5MrysZEPv8I+FHWtmk8uZBzrqOUd52eJxdyznWYMi+drvmS0waJneCcvfaKVTr8ssua3Y+OkTS3GGCVQYNiZT8YPTqx7uFbbRUrW3ullWJl+44dm69zztWn7mh69y9+kTmobXLyyYWK3h2TXMhllxSMnesUZb5D9oDsnOsoZQ7IacmFhki6U9LlkpaXNFXSu5KmSVqnRjtPLuScaw/l2Aom7Q55NHACMJBgVsVPzWyYpM3Dfd9MauTJhZxzbdOngJE2o5oP9SQ9bGbrhJ//aWYrJO1L4QG5RZIWfMz+6KNY2TlTpiS2T3rz95+fTE5OdVuVcufqVHc0/cupJ2WOOd8+7oRCRe+0O+SPJG0JDABM0o5mdoOkTYBPm98955zLp8xjyGkB+QDgf4F5BAtEDpR0McHa7B83t2vOOdcDJR6yqPlQz8weMbOtzGxrM5thZj8xs4Fmtibw1Rb10TnnMpOyb0XT44UhlWPKNfgYchtN2H//WNmsN99MqJn85u+z9twzse6n8+bFyo6eODFn73qvM0fEE38ddeWVbehJ4dQdJu8945TMMedbRx9fqLDsyYWccx2lk8eQPbmQc65c+nZuQPbkQs65UinzHXJbkgslOXrbbRPLz7j55oZ2ppb9hw6NlY29szNfjLLlGvE3kg9fJz6t/Igrrkhsf8Yee8TKmjGGfOQ22ySW/2ZyrjSzrjzqjqb3/fbXmYPahof9v0JFb89l0QslBWPnOkWZ75B7/OZoSbek13LOuRbr1FwWktattgsYXKNdF9AFMHbsWLq6unraP+ecy0UlXhiSNmQxDbib5N8lA6s18uRCzrm2KXFATksu9Diwk5n9I2HfS2a2fIZzeEAuiaQHddUk
PcBLepPJnLlzE9ufNmlS9o653qTuaDpt7BmZY8439j+6UNE77Q75RKqPMx/a2K4451wDlPgOOS2XxTWAJG0uqX/F7nheR+eca7My57JIe2PIYcCNBHfDj0uKJtw9rZkdc865HilxRE4bsvgxsJ6ZzZa0EnCNpJXM7FwKOWnE1SPPwo6k8eITr7suVnbCTjsltt/3O9+JlU34858zn9+5atTBS6f7mtlsADN7QdKmBEF5RTwgO+eKqIB3vlmlLQx5VdLg7i9hcN4OWAr4ehP75ZxzPVLiEYvUgLw38Gq0wMzmmtnewMZN65VzzvVUiSNyzSELM5tVY99fG98dVxZJ84uTxotPuv76zMcsQoKpZki6rrJfU5GpxwkhEo4lDQfOBfoCF5rZqIr9qwMTgHWB48zszMi+F4D3Cd4/OtfMhqSdL3dyIUlLm9lreds51xMeuFxuDbrzldQXOB8YBswCpkmaZGbRV66/BRwG7FjlMEPN7I2s50zLZbFEZRFwv6R1CFb5vZX1RM451woNzGWxPvCsmc0EkHQVsAPwWUAOb05fk5T8v3c5pd3cvwE8GNkeAJYFHgo/J5LUJekBSQ+MGzeuWjXnnGu8HGPI0VgVbtFMaMsCL0W+zwrLsjLgNkkPVhy3qrQhi58BWwBHm9ljAJKeN7OVa/bCkws559okzxhyRayKHSqpSY6ubGRmL0taGpgqaYaZ3VOrQdpDvTPD2/SzJb0EnJCzQ65DJSUHSlrskcfAhReuq31RFXEcfOTG8UlS4++pGSvKo3GzJ2YB0QRqywEvZ21sZi+Hf74m6XqCIZCaf8mpv0vMbJaZ7QbcCUwFOvNfjXOuMzRu2ts0YFVJK0uaH9gdyJSmUNIikhbt/gxsCTye1i51lkU4rWNZgoD8J+ArYflwM7s1S+ecc65VGnWDbGZzJR0CTCGY9naRmT0h6YBw/xhJgwiepy0GzJN0OLAGweK568PXSfUDrswSL9NmWRwGHAw8BYwHfmJmN4a7TwM8IDvnCqWRbwwxs8nA5IqyMZHPrxIMZVR6D1g77/k8uVAPHDF8eGL5Wbf27t9PScmBqi32SBovPu7qq2Nl31pppcT2977wQq6+uc/rmPHiJAVcgZeVJxdyznWWBq7UazVPLuSc6ygK5hdn2orGkws55zqLcmwF48mFnHMdRX3KO2ZR863TDeILSUrsyG22SSz/zeTJieX1SHqAV+3h3fE77BAre+eDD2Jl502dWm+3XGvVfd/6xK3jMsecNYd3Feo+OXe2N+ecK7Iijg1nlfaS04ckHS/pK3kO6smFnHNt06ljyMDiwEDgTkmvAhOB33ev0a7Gkws559qlzHfIaQH5bTM7CjhK0neAPYCHJD0FTAwDr+tgzRgrriZpvDhprBjglBtvjJXts9FGje6So/43wbRceeNx9inUZvZnMzuIIK/F6cA3m9Yr55zrIfVR5q1o0u6Qn6ksMLNPCXJY9O51ws65YirxkEXNO2Qz213S6pI2l9Q/ui98+Z9zzhVKmVfqpWV7OxQ4hDDbmyTP9uZaKmluMSSPF1/y1/hapTP22COx/RKLLhorG+kzghIVerw4SfHibGZpQxZdeLY351yZFPDONyvP9uac6ygljsee7c0511nKPMvCs7055zpL496p13Ke7S3i5F13jZX94ppr2tAT1y1PcqCkB3hHT5yYuf2ZI0bEyo668srM7YvqmO22i5WdftNNbehJixQvzmbmyYWccx2liNPZskpLLjRA0ihJMyS9GW5PhWUDa7Tz5ELOufbo4ORCfwDuADYN365K+NrrfYCrgWFJjTy5kHOuXYr4sC6rmgnqJT1tZl/Nu69Crw7IvW78ro3Gd3XFyupd7HHqbrsllie9Ids1RN3R9LkHrsgcc74yZM9CRe+0WRYvSvqZpGW6CyQtI+kY4KXmds055/Ir8SSL1ID8fWBJ4G5Jb0t6C7gLWAL4XpP75pxz+ZU4IqdNe3tb0rXANWY2TdKawHDg
KTN7qyU9dM65PIoXZzNLG0M+AdiaIHBPBdYH7ga2AKaY2akZzpFpPOcXO+6YWH7yDTdkad5SZ//gB4nlP7388hb3xKVJmlsM9c8vnrD//rGyfceOreuYrfSrXXaJlf3y2mvb0JOYusPp89MnZh5DXnnwHoUK32mzLHYFBgMLEKzYW87M3pN0BvB3IEtAds65linzLIu0MeS5ZvapmX0APGdm7wGY2YfAvKb3zjnn8mrgPGRJwyU9LelZSccm7F9d0t8kfSzpqDxtk6QF5DmSFg4/rxc50QA8IDvniqhBD/Uk9QXOJxi2XQPYQ9IaFdXeAg4DzuxB25i0gLxxeHeMmUUD8HwEi0Occ65QGjjJYn3gWTObaWZzgKuAz71118xeM7NpwCd52yZJm2XxcZXyN4A30g6eRxEf3lXjD+/Ko9rDu6QFH3kWe7w9e3asbPS++ybWPWjChMzHbZWCPMBrjhzT2SR1EbyIo9u4cKUxBC90jq63mAVskPHQPWrryYWccx0lz/TiijQPsUMlNcnajZ60TUsuNEjSBZLOl7SkpBMlPSbpD5K+WKOdJxdyzrVHH2XfapsFLB/5vhzwcsZe9Kht2h3yxcDNwCLAncAVwLYEYyFjqDIm4smFnHPtosatDJkGrCppZeBfwO5A8sT2BrVNWxjysJmtE37+p5mtENk33cwGZ+iYB2RXCkmLPZLGigGOuOKKWFm1N1y/+Eb8cUuexPu9TN3R9KVnrskcc5Zfbdea55O0DXAO0Be4yMxOlXQAgJmNCbNfPgAsRjDzbDawRrheI9Y2rT9pd8jRIY1La+xzzrliaGCOCjObDEyuKBsT+fwqwXBEprZp0gLyjZL6m9lsMzu+u1DSKsAzeU7knHOtUMCcQZmlTXv7paT1JVmYXGgNguRCM8ws/gI655xrszIvnc6bXGgDgvSbDU8u5Op3yLD4C1x8rLI+1eYW/+ejj2JleV6oWuDkPu1WdzT91/PXZ445y668U6GitycXcs51lkKF2HzSAvJcM/sU+EDS55ILSfJcFs65wunYt07jyYWcc65l0u6QN+7OZ+HJhZxzpVDiO+SaD/UaxB/qVdh+8ODE8knTp7e0H61w9LbbJpafcfPNLe5JY7X7Aeqo738/Vnbs73/fsvM3Ud3R9JVZkzLHnC8ut32honfuxR2SlmxGR5xzriEamKC+1dKSC42StFT4eYikmcDfJb0oaZMa7Ty5kHOuLSRl3oombQx5WzPrfvXIGcD3wwUiqwFXAkOSGnlyIedc2xQvzmaWtjBkBvA/ZjZX0n1mtmFk32Nm9vUM5/CAXBJJb2iu9+3MvUnSYg/IvuDjoM03TywfffvtPe5TCdUdTv/96k2ZY84yg7YrVPhOu0M+H5gsaRRwq6RzgOuAzYHpze2ac871QAGHIrJKy2Xxf5IeAw4EVgvrrwbcAJzS9N4551xORRwbzirLK5w+AM4Mx47XJEguNMvMKl/q55xzBVDegJw3udD6wN14ciHn6pI0XlxtrDhpbHrO3LmxslNuvLH+jrVf3dH0tdemZI45Sy+9VaGitycXcs51lE4esvDkQs65kuncgDxH0sJm9gGeXMg5VwJS33Z3occ8uZBzrsOU9w7Zkws5VxB5Fpac9r3vxcpee/fdxPbnTJlSX8daq+5o+uabd2eOOUsuuUmhoneWaW/OOVcihYqxuaQlF+ov6VeSnpD0rqTXJd0n6Ycp7Ty5kHOuTcqb7i3tDvkK4HpgK+B7wCLAVcDxklYzs58nNfLkQs65dpFyZxUujLSAvJKZXRx+PkvSNDM7WdK+wJNAYkBupDwT6J0roqzJ5JMWe0DyePHP//CHWNn+Q4f2oHedqLwBOa3n/5H0bQBJ3wXegs9mXBTvft851+t1cj7kA4HfhfmPHwf2A5D0BYJMcM45VzDlvUNOy/b2iKRDgXlhcqE1JB0BzDCz37ami845l0fx7nyzyptcaAPgLjy5UGYn7rxzvOy669rQE1dWh2+1
VazswzlzYmVj77wzsf2pu+0WK0sarz7p+ut70LuGqzuavvPOA5ljzsCBQ2qeT9Jw4FygL3ChmY2q2K9w/zYEmTF/aGYPhfteAN4HPiVIQ5H4hqUoTy7knOsojVo6reBA5wPDgFnANEmTzOzJSLWtgVXDbQPggvDPbkPN7I2s50wbbJlrZp+GuSw+l1wIz2XhnCukhs1DXh941sxmmtkcgim/O1TU2QG41AL3AQMlfbGnPU8LyHMkLRx+9uRCzrkSyB6Qo4vYwq0rcqBlgZci32eFZWSsY8Btkh6sOG5VnlzIOddR8kxnq1jEFjtUUpMcdTYys5clLQ1MlTTDzO6p1R9PLpRiu7XWipXd9OijbeiJ64ljttsusfz0m25qcU+aL+nhHcBxV18dKzv2u9+NlS216KKJ7Vv85vG6H+q9997jmWPOYov9T9XzSfomcKKZbRV+/38AZvbrSJ2xwF1mNjH8/jSwqZm9UnGsE4HZZnZmrf6Ud8Kec84latgY8jRgVUkrS5of2B2YVFFnErC3AhsC75rZK5IWkbQogKRFgC0J1nLUlJZcaDFJv5Z0maQRFftG12jnyYWcc20h9cm81WJmc4FDgCnAU8AfzOwJSQdIOiCsNhmYCTwL/A44KCxfBviLpEeA+4GbzezWtL6njSFPAP4BXAvsJ2kXYEQ4rrxhjQvx5ELOuTZp3MIQM5tMEHSjZWMinw04OKHdTGDtvOdLWxgy3cwGR74fRzABentgqpmtm+EcHpBdzMiNN46Vjb+n5vOOpktKEJ+UHL6oTthpp8TyjxMWgYz64x8zH/fn228fKzttUuX/uTdM3dF09uynM8ec/v2/WqhlfWl3yAtI6tM9w8LMTpU0C7gH6N/03jnnXG7lfTSW1vM/AptFC8zsEuBIIL520znn2q5DE9Sb2c8qyyRdamZ7EywVdM65QiliWs2s0saQKweKBAwF7gAws/jgUpyPITvXRmeOGBEryzO3+MDNNouVDVlllcS6I+ufVVV3NP3Pf57PHHMWWWTlQkXvtDHk5YEngAsJAquAIcBvmtwv51zJNCAYN0SZ75DTxpDXAx4EjiOY8HwX8KGZ3W1mdze7c845l1/njiHPA86WdHX457/T2jjnXHuVd5ZFpuBqZrOA3SRtC7zX3C4551zPlXnIwpMLuVJKWgRRkDdelELSYo+3Z89OrHvBHXfEyo7edttYWdJbTADOmzo1T9fqjqYfffRy5piz4IJfKlT09uEH51yHKVSMzSUtudDwyOcBksZLelTSlZKWqdHOkws559qkQx/qAacB3RmKfgO8AnwX2BkYC+yY1MiTCznn2qVjx5AlPdSdQCgh0dDnvtfgAdm5EhjflfyWoRn/+les7Iybb46VnbXnnontF15wwVjZAePHV+tG3dH0449fyxxzFlhg6UJF77Q75KUlHUHwl7SYJNl/I3h555Y45zpYoWJsLmkB+XdA93tdLgGWAl6XNAiY3sR+Oedcj0h9292FHktbGHJSZVkkudDeTeuVc871WIfeISckFwLYTNJAyJxcyDnnWqaTH+o9TDy50ESCl/2RMZ9Fxz3Uu2C//RLLD7zoohb3xLnmO2TYsFjZl5deOlZ2xBVXZD7mLuutl1h+7YMP1h1NP/nk3cwxZ775BhQqentyIedch+nQecieXMg5VzZlHrLw5ELOuQ5T3hm5ue52zexmID4jvIaksaJrH3wwzyEKx8eKXW+SlBxozMiRdR3z21/7Wl3ta+n4O2TnnCuP8gbktORCQyTdKelySctLmirpXUnTJK3Tqk4651x25X2olzbYMhr4X4JhinuBsWY2ADg23Jcomu3t+TfeaFhnnXMujaTMW9GkzkM2s3XCz/80sxWS9tWyohQ7wZF77RWrd9ill2bts3OugKrNLU4aL/7p5ZdXO0zdUXLevE8yz0Pu02e+QkXltDHkjyRtCQwATNKOZnaDpE2AT5vfPeecy6eId75ZpQ1ZHAgcCewHbAUMlfQ2
wXDFT5rcN+ec64HGjSFLGi7paUnPSjo2Yb8k/Tbc/6ikdbO2TZK2MGQ6QSDu9hNJS5hZfMzBOecKoTF3yArSxp0PDANmAdMkTTKzJyPVtgZWDbcNgAuADTK2jelpcqFJ4MmFnHPF08Ahi/WBZ81sZnjcq4AdgGhQ3QG4NMwTf5+kgZK+CKyUoW2cmVXdgIeBy4FNgU3CP18JP29Sq22V43U1um4zjlmm85epr+0+f5n62u7zF6GvrdiALuCByNYV2bcrcGHk+17AeRXtbwK+Hfl+OzAkS9ukrdXJhZLfEVNf3WYcs0znz1O3t58/T93efv48dZt1/qYzs3FmNiSyRd/KnHSrXTmDo1qdLG1jPLmQc84lmwUsH/m+HPByxjrzZ2gbkykLh5nNMrPdgFsIhjCcc67TTQNWlbSypPkJ8sBXPlebBOwdzrbYkGAk4ZWMbWOanlyowrj0KrnrNuOYZTp/nrq9/fx56vb28+ep26zzt5WZzZV0CDAF6AtcZGZPSDog3D8GmAxsAzwLfADsW6tt2jlrrtRzzjnXOuVNHOqccx3GA7JzzhWEB2TnnCuIpgZkSatLOiZc631u+DmW+imst7mk/hXlwzOcIzFNnKQNJC0Wfl5I0kmS/ijpdEkDIvXml7S3pC3C7yMknSfpYEnz5b1mV52k+KuKq9ddspl9aaSs19WJ1xTWLc11FV3TArKkY4CrCCZI308wDUTAxGiiDUmHATcChwKPS9ohcpjTKo45qWL7I7Bz9/eKLlxE8NQT4FyCjHWnh2UTIvUmANsS5Om4DNgN+DvwDeDCHv8F1KmV/yAkDZA0StIMSW+G21Nh2cBIvcUk/VrSZZJGVBxjdMX3JSq2JYH7JS0uaYmKuqMkLRV+HiJpJvB3SS+GmQWjdTO9NCHrNTXruppxTWX6WeW5JhfRxCWJzwDzJZTPD/wj8v0xoH/4eSWC5Ys/Cb8/XNH2ITIu5Qaeirar2Dc98vnR8M9+wL+BvuF3de+raDsAGAXMAN4Mt6fCsoGReosBvwYuA0ZUHGN0xfclKrYlgReAxYElKuqOApYKPw8BZhJMuXkx+ncQ7rsz/PtaHpgKvEvwi3GdimNOAY4BBkXKBoVlUyNl14bn35FgTuW1wAJV/o7nAc9XbJ+Ef86sqPtY5POdwDfCz6sBD1TUvZ8gocsewEvArmH55sDf8l5Ts66rGddUpp9VnmvyLfL327QDBwFrxYTyFYGnI9+frNjfH7gVOItI4Az39QF+ShBcBodlM6uc/2pg3/DzBGBI5D+caZF6jxP8klgceJ8wAAILEgnqkfod9w8i+vNIuN7oz6ry53Ec8FeCXyCV13RU+HP8eqTs+Rr/rfQLP99X7XrD7w9HPv+zxr5M19Ss62rGNZXpZ5XnmnyL/N007cAwnODO7RaCyeDjwh/6s8DwSL07CINrpKwfcCnwaZVjL0cQcM+r/GFH6gwALgaeIxiC+ITgbvJuYO1IvZ+G5S8ChxEkB/kdwZ37CQnH7bh/EMBtwM+AZSJlyxD8kvlTpOwpoE9F232AJ4AXa/yczgIWpfovz0PDPmwGnAicA2wMnARcVlH3b8CWBENLLwI7huWb8PlfSJmuqVnX1YxrKtPPKs81+Rb5+23qwYM72g2BXQiyH21IOCRQ8R/CoCrtN0o5/rbAaSl1FgXWJkiUtEyVOl8CvhR+Hhj2df0qdTvuHwTB/x2cThDs3wbeCvt/OpEhE4L3K26R0KfhRIahEvZ/F7gPeLVGnU2B3xNkGHyMYAVUFxXDXuHPcgrBL/rVCZ4PvBP+vX4r7zU187rqvKa3w2vaqKJu5XW9HV7X/zboZ7V9hp/V0ITr2j96XcDgrNfkW+Tvtt0dKNtW8Q/irYp/6ItH6rUjePWL1MkUuCL1Vwe2IBzPj/Y3od7mCfW2rnLMzQmGoRYC/ifpmCnHTar7tSx1CfLZdg/prEnw9pttqvydRuuuARyRse7XgeOT6uY8/wZZ6ya0vSxjvUsz
1lsIuDrHv4msx83Uz968+dLpBpK0r5lNqLeepIWAr5jZ41mPWc/5Fcx0OZjgF8tggoeqN4b7HjKzdcPPhwKHpNXLc8we1j2I4Bdirb6eQDCG3o/gmcP6BMNVWwBTzOzUyDEr624A3JWxbuJx6zx/rbqJL40gGPrDwpdGJNQTwZ3t5+rlOWad5696TBfR7t8InbRRZTy7p/WaVbeyHhlnumStV4S6Yb2+wMLAe8BiYflCVMyeaUbdJp4/00wjcrxcIusxm3V+3/67eW7jnCQ9Wm0XwVhyrnrNqpvnmATj+rMBzOwFSZsC10hakc8n2s5arwh155rZp8AHkp4zs/fCNh9KmldxzGbUbdb5hxC8YPg44Ggzmy7pQ4u/MGK9jPXyHLNZ53chD8j5LUPw4te3K8oF3NuDes2qm+eYr0oabMFLbTGz2ZK2I1hc8/Ue1CtC3TmSFjazDwiCQ3DxwSrNyiDXjLpNOb9lfGlE1nrNqpvnmC6i3bfoZduA8UTeoVWx78q89ZpVN+cxM810yVqvCHUJ53wn1FmKyPTCZtVt1vkT6qTONMpTr1l18xyzN2/+UM855wrCs70551xBeEB2zrmC8IDsnHMF4QHZOecKwgOyc84VxP8Hv85ONLvZd1EAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 432x288 with 2 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Ground-truth graph for the same sample as the prediction heatmap above.\n",
    "# Relies on `idx`, `plt`, `sns` and `squareform` from the previous cell.\n",
    "fig, ax = plt.subplots()\n",
    "sns.heatmap(squareform(w_gt_batch[idx, :].detach().cpu().numpy()), cmap='pink_r', ax=ax)\n",
    "ax.set_title('groundtruth (sample {})'.format(idx))\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
