{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "import gzip\n",
    "import argparse\n",
    "import os\n",
    "import sys\n",
    "from torch.distributions.categorical import Categorical\n",
    "\n",
    "sys.path.append('../../')\n",
    "from guacamol.utils.helpers import setup_default_logger\n",
    "\n",
    "from smiles_rnn_distribution_learner import SmilesRnnDistributionLearner\n",
    "\n",
    "setup_default_logger()\n",
    "\n",
    "# Run configuration, declared as command-line style options so every default is listed in one place.\n",
    "parser = argparse.ArgumentParser(\n",
    "    description='Distribution learning benchmark for SMILES RNN',\n",
    "    formatter_class=argparse.ArgumentDefaultsHelpFormatter,\n",
    ")\n",
    "\n",
    "# String-valued options: (flag, default, help text).\n",
    "path_options = [\n",
    "    ('--train_data', '../../data/QM9/QM9_clean_smi_train_smile.npz',\n",
    "     'Full path to SMILES file containing training data'),\n",
    "    ('--valid_data', '../../data/QM9/QM9_clean_smi_test_smile.npz',\n",
    "     'Full path to SMILES file containing validation data'),\n",
    "    ('--prop_model', '../../data/QM9/prior.pkl.gz',\n",
    "     'Saved model for properties distribution'),\n",
    "    ('--output_dir', './output/Qm9/', 'Output directory'),\n",
    "]\n",
    "for flag, default, help_text in path_options:\n",
    "    parser.add_argument(flag, default=default, help=help_text)\n",
    "\n",
    "# Typed hyper-parameters: (flag, default, converter, help text).\n",
    "typed_options = [\n",
    "    ('--batch_size', 512, int, 'Size of a mini-batch for gradient descent'),\n",
    "    ('--valid_every', 1000, int, 'Validate every so many batches'),\n",
    "    ('--print_every', 10, int, 'Report every so many batches'),\n",
    "    ('--n_epochs', 10, int, 'Number of training epochs'),\n",
    "    ('--max_len', 280, int, 'Max length of a SMILES string'),\n",
    "    ('--hidden_size', 512, int, 'Size of hidden layer'),\n",
    "    ('--n_layers', 3, int, 'Number of layers for training'),\n",
    "    ('--rnn_dropout', 0.2, float, 'Dropout value for RNN'),\n",
    "    ('--lr', 1e-3, float, 'RNN learning rate'),\n",
    "    ('--seed', 42, int, 'Random seed'),\n",
    "]\n",
    "for flag, default, converter, help_text in typed_options:\n",
    "    parser.add_argument(flag, default=default, type=converter, help=help_text)\n",
    "\n",
    "# parse_known_args() returns (namespace, leftover arguments); the leftovers are ignored here.\n",
    "args, _ = parser.parse_known_args()\n",
    "\n",
    "# Force a small batch size regardless of the parsed --batch_size value.\n",
    "args.batch_size = 6\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 1. Prepare the training and validation data\n",
    "import numpy as np \n",
    "import random \n",
    "from rnn_utils import get_tensor_dataset, load_smiles_and_properties, set_random_seed\n",
    "\n",
    "# Fall back to the current working directory when no output directory is given;\n",
    "# `__file__` is not defined inside a notebook kernel, so it cannot be used here.\n",
    "if args.output_dir is None:\n",
    "    args.output_dir = os.getcwd()\n",
    "\n",
    "os.makedirs(args.output_dir, exist_ok=True)\n",
    "\n",
    "# Load the pickled property model, closing the gzip handle once it has been read.\n",
    "# NOTE(security): unpickling can execute arbitrary code - only load files from a trusted source.\n",
    "with gzip.open(args.prop_model) as prop_model_file:\n",
    "    prop_model = pickle.load(prop_model_file)\n",
    "\n",
    "training_set = args.train_data\n",
    "validation_set = args.valid_data\n",
    "# NOTE(review): hard-coded here; the --max_len option (default 280) is not consulted - confirm intended.\n",
    "max_len = 100\n",
    "train_seqs, train_prop = load_smiles_and_properties(training_set, max_len)\n",
    "\n",
    "# Hold out the first `n_valid` entries of a random permutation for validation.\n",
    "# The permutation comes from a generator seeded with args.seed so the split is\n",
    "# identical on every run (the global seed is only set in a later cell).\n",
    "n_valid = 10000\n",
    "sample_indexs = np.random.RandomState(args.seed).permutation(train_seqs.shape[0])\n",
    "valid_indexs, train_indexs = sample_indexs[:n_valid], sample_indexs[n_valid:]\n",
    "train_x, train_y = train_seqs[train_indexs, :], train_prop[train_indexs, :]\n",
    "valid_x, valid_y = train_seqs[valid_indexs, :], train_prop[valid_indexs, :]\n",
    "\n",
    "# Apply the loaded property model's transform to both splits.\n",
    "if prop_model is not None:\n",
    "    train_y = prop_model.transform(train_y)\n",
    "    valid_y = prop_model.transform(valid_y)\n",
    "\n",
    "train_set = get_tensor_dataset(train_x, train_y)\n",
    "valid_set = get_tensor_dataset(valid_x, valid_y)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch.utils.data import DataLoader\n",
    "import torch\n",
    "import logging\n",
    "\n",
    "PROPERTY_SIZE = 9\n",
    "\n",
    "# Reuse the values already held in `args` instead of repeating the literals here.\n",
    "hidden_size = args.hidden_size\n",
    "n_layers = args.n_layers\n",
    "rnn_dropout = args.rnn_dropout\n",
    "batch_size = args.batch_size\n",
    "num_workers = 0\n",
    "\n",
    "cuda_available = torch.cuda.is_available()\n",
    "device_str = 'cuda' if cuda_available else 'cpu'\n",
    "device = torch.device(device_str)\n",
    "\n",
    "set_random_seed(args.seed, device)\n",
    "\n",
    "data_loader = DataLoader(train_set,\n",
    "                         batch_size=batch_size,\n",
    "                         shuffle=True,\n",
    "                         num_workers=num_workers,\n",
    "                         # Pinned host memory is only useful when batches are copied to a GPU.\n",
    "                         pin_memory=cuda_available)\n",
    "\n",
    "# Only one batch is needed by the cells that follow, so take the first one\n",
    "# instead of iterating over the whole loader.\n",
    "batch_all = next(iter(data_loader))\n",
    "batch = batch_all[:-1]                 # every element except the last\n",
    "properties = batch_all[-1].to(device)  # last element, moved to the compute device\n",
    "inp, tgt = batch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "from scipy.spatial.distance import cdist\n",
    "\n",
    "# SciPy works on host NumPy arrays and a CUDA tensor cannot be converted implicitly,\n",
    "# so copy the batch properties back to the CPU first.\n",
    "query_props = properties.detach().cpu().numpy()\n",
    "# K[i, j] is the Euclidean distance between batch row i and training row j.\n",
    "K = cdist(query_props, np.asarray(train_y, dtype=np.float32), metric='euclidean')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "1. Sample according to k-nearest neighbours (kNN) in property space"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of nearest neighbours kept for each row of K.\n",
    "sample_size = 10\n",
    "# Order the column indices of every row by increasing distance, then keep the closest ones.\n",
    "neighbour_order = np.argsort(K, axis=1)\n",
    "selected = neighbour_order[:, :sample_size]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(5, 10)"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "selected.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([-1.5999, -0.4318,  0.2557, -0.5177, -0.9237,  0.6546, -0.1682,\n",
       "         0.8152,  0.6067], device='cuda:0')"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "properties[0,:]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([-1.59989955, -0.43184623,  0.25568422, -0.51772474, -0.92373615,\n",
       "        0.65461782, -0.16818697,  0.81516874,  0.60673663])"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_y[selected[0,0],:]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([-1.34049748, -0.2083466 ,  0.27714319, -0.52015953, -0.76480597,\n",
       "        0.57043644, -0.39001386,  1.03966174,  0.36786741])"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_y[selected[0,1],:]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(5, 113885)"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "K.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAX0AAAD8CAYAAACb4nSYAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAAFLZJREFUeJzt3X+MZeV93/H3J0vAMa4dEkaVs8uy62QdhbQSpJN1Wze0igGv5WjXf9hiLbkiFdLWEds4RVWDmwjCRpEcp7JaqavYyN7WSYO3GBJpVG1KSCBprRZ7Z4GY7JKthzWB2bplY6hd1wRY+PaPexZdbmeYMzN35p6Z835JI86P58x8Z7n3c555znPOTVUhSeqH75l0AZKk9WPoS1KPGPqS1COGviT1iKEvST1i6EtSjxj6ktQjhr4k9YihL0k9ctGkCxh1+eWX144dOyZdhiRtKCdOnPjLqppaql3nQn/Hjh3Mzs5OugxJ2lCS/EWbdg7vSFKPGPqS1COGviT1iKEvST1i6EtSj7QK/SR7kpxOMpfktgX2fzTJ40keS/KlJFc123ckeaHZ/liST4/7F5AktbfklM0kW4DDwPXAPHA8yUxVnRpqdndVfbppvxf4FLCn2fdkVV093rIlSSvRpqe/G5irqjNV9RJwFNg33KCqvj20eingZzBKUge1Cf2twDND6/PNttdJckuSJ4FPAj8/tGtnkkeT/EmSn1pVtZKkVRnbhdyqOlxVPwz8IvDLzeZvANur6hrgVuDuJG8dPTbJgSSzSWbPnTs3rpIkcmfInZl0GVJntAn9s8AVQ+vbmm2LOQp8AKCqXqyqbzbLJ4AngXeOHlBVd1XVdFVNT00t+egISdIKtQn948CuJDuTXAzsB2aGGyTZNbT6fuBrzfap5kIwSd4B7ALOjKNwaSH26qU3tuTsnao6n+QgcD+wBThSVSeTHAJmq2oGOJjkOuBl4Hngpubwa4FDSV4GXgU+WlXPrcUvon7InaHucJ6AtFKtnrJZVceAYyPbbh9a/tgix90H3LeaAiVJ49O5RytLK+GwjtSOj2FQpzjbRlpb9vTVeYudBBzfl5bP0NeG418C0soZ+trQ3ugE4MlB+v8Z+tp0DHtpcYa+OsvwlsbP2TuS1COGvnrHvyDUZ4a+JPWIoa9esHcvDRj6ktQjhr4k9Yihr17yGT/qK+fpa91cCNm6oxZdHm27Fj9f6jN7+poog1haX4a+1t1GCPqNUKO0Eoa+OmNSQev4vvrE0FevGfbqG0NfknrE0Ne6sEctdUOr0E+yJ8npJHNJbltg/0eTPJ7ksSRfSnLV0L6PN8edTvLecRYvjVObE5MnL210S87TT7IFOAxcD8wDx5PMVNWpoWZ3V9Wnm/Z7gU8Be5rw3w/8OPBDwB8meWdVvTLm30Mai+UEv5/Pq42oTU9/NzBXVWeq6iXgKLBvuEFVfXto9VLgwrthH3C0ql6sqq8Dc833kyRNQJs7crcCzwytzwPvGm2U5BbgVuBi4KeHjn145NitCxx7ADgAsH379jZ1S2vOHr02o7FdyK2qw1X1w8AvAr+8zGPvqqrpqpqempoaV0mSpBFtevpngSuG1rc12xZzFPjNFR6rTcYLn1K3tOnpHwd2JdmZ5GIGF2Znhhsk2TW0+n7ga83yDLA/ySVJdgK7gK+svmxtBJsp8L1rV5vFkj39qjqf5CBwP7AFOFJVJ5McAmaragY4mOQ64GXgeeCm5tiTSe4BTgHngVucuSNJk5Oqbl2kmp6ertnZ2UmXoTHY7D1jL/CqS5KcqKrppdp5R660Qpv9pKbNydCXpB7xk7OkMRju9Tvsoy6zp6+xcshD6jZDX5J6xNCXpB5xTF+r0vfn0zicpY3G0NdYGH7SxmDoa+w8AUjd5Zi+JPWIoS9JPWLoS1KPGPqS1COGviT1iKEvjZkfuKIuM/QlqUcMfUnqEUNfK+YQhrTx
GPrSGnJ8X11j6EtSj7QK/SR7kpxOMpfktgX235rkVJKvJvmjJFcO7XslyWPN18w4i5ckLc+SD1xLsgU4DFwPzAPHk8xU1amhZo8C01X13SQ/B3wSuLHZ90JVXT3muqXOc1hHXdSmp78bmKuqM1X1EnAU2DfcoKoeqqrvNqsPA9vGW6a6wiCTNrY2ob8VeGZofb7Ztpibgd8fWn9TktkkDyf5wApqlCSNyVifp5/kI8A08PeHNl9ZVWeTvAN4MMnjVfXkyHEHgAMA27dvH2dJWgP29qWNq01P/yxwxdD6tmbb6yS5DvglYG9VvXhhe1Wdbf57Bvhj4JrRY6vqrqqarqrpqampZf0CkqT22oT+cWBXkp1JLgb2A6+bhZPkGuAzDAL/2aHtlyW5pFm+HHg3MHwBWJK0jpYc3qmq80kOAvcDW4AjVXUyySFgtqpmgN8A3gJ8MQnA01W1F/gx4DNJXmVwgvnEyKwfSdI6SlVNuobXmZ6ertnZ2UmXoUU4nr9ydcfgvZY789qyNC5JTlTV9FLtvCNXknpkrLN3JC3Ov5LUBfb0JalHDH1J6hFDX5J6xNCXpB4x9CWpR5y9oyU562T8LvybOl9f682eviT1iKEvST1i6EtSjxj60gR5vUTrzdCXpB4x9CWpRwx9SeoRQ1+SesTQl6QeMfQlqUd8DIMW5XRCafOxpy9NWO6MJ1itG0Nf6hBPAFprrUI/yZ4kp5PMJbltgf23JjmV5KtJ/ijJlUP7bkrytebrpnEWL20mhr3Ww5Jj+km2AIeB64F54HiSmao6NdTsUWC6qr6b5OeATwI3JvkB4A5gGijgRHPs8+P+RaTNZPgE4OOXNU5tevq7gbmqOlNVLwFHgX3DDarqoar6brP6MLCtWX4v8EBVPdcE/QPAnvGULklarjahvxV4Zmh9vtm2mJuB31/hsZKkNTTWC7lJPsJgKOc3lnncgSSzSWbPnTs3zpK0DI4pS5tfm9A/C1wxtL6t2fY6Sa4DfgnYW1UvLufYqrqrqqaranpqaqpt7ZKkZWoT+seBXUl2JrkY2A/MDDdIcg3wGQaB/+zQrvuBG5JcluQy4IZmmzrOXr+0OS05e6eqzic5yCCstwBHqupkkkPAbFXNMBjOeQvwxSQAT1fV3qp6LsmvMjhxAByqqufW5DfRWBj20ubW6jEMVXUMODay7fah5eve4NgjwJGVFihJGh/vyJWkHjH0JalHDH1J6hFDX5J6xNCXpB7xQ1TkNE2pR+zpSx3nSVnjZOhLUo8Y+tIG4CdqaVwMfUnqEUNfknrE0JekHjH0pQ3EcX2tlqEvbTBe1NVqGPqS1COGvrSB2evXchn6ktQjhr4k9YihL0k9Yuj3nOPBUr+0Cv0ke5KcTjKX5LYF9l+b5JEk55N8cGTfK0kea75mxlW4VscLgBuf//+0Eks+Tz/JFuAwcD0wDxxPMlNVp4aaPQ38LPDPFvgWL1TV1WOoVZK0Sm0+RGU3MFdVZwCSHAX2Aa+FflU91ex7dQ1qlCSNSZvhna3AM0Pr8822tt6UZDbJw0k+sKzqJEljtR4fl3hlVZ1N8g7gwSSPV9WTww2SHAAOAGzfvn0dSpKkfmrT0z8LXDG0vq3Z1kpVnW3+ewb4Y+CaBdrcVVXTVTU9NTXV9ltrhbwAKPVXm9A/DuxKsjPJxcB+oNUsnCSXJbmkWb4ceDdD1wIkSetrydCvqvPAQeB+4Angnqo6meRQkr0ASX4yyTzwIeAzSU42h/8YMJvkT4GHgE+MzPqRJK2jVmP6VXUMODay7fah5eMMhn1Gj/uvwN9cZY2SWrgwbFd31IQrUZd5R64k9YihL0k9YuhLUo8Y+tIm45RcvRFDX9oEFgt6H6ynUYa+JPWIoS9JPWLoS1KPGPqS1COGviT1yHo8Wlkd4AwOSWBPX5J6xdCXpB4x9KVNaPSmLIf3dIGhL0k94oXcTc4enqRh9vSlnvA5PAJDX5J6xdCXesbefr8Z+pLUI4a+JPVI
q9BPsifJ6SRzSW5bYP+1SR5Jcj7JB0f23ZTka83XTeMqXJK0fEuGfpItwGHgfcBVwIeTXDXS7GngZ4G7R479AeAO4F3AbuCOJJetvmxJ0kq06envBuaq6kxVvQQcBfYNN6iqp6rqq8CrI8e+F3igqp6rqueBB4A9Y6hbb8CpeZIW0yb0twLPDK3PN9vaWM2xWiWDX4uxY9BfnbiQm+RAktkks+fOnZt0OZK0abUJ/bPAFUPr25ptbbQ6tqruqqrpqpqemppq+a0lScvVJvSPA7uS7ExyMbAfmGn5/e8HbkhyWXMB94Zmm6SOcKinX5YM/ao6DxxkENZPAPdU1ckkh5LsBUjyk0nmgQ8Bn0lysjn2OeBXGZw4jgOHmm2SOsbg74dWT9msqmPAsZFttw8tH2cwdLPQsUeAI6uoUZI0Jp24kCtJWh+GviT1iKEv9Zjj+P1j6EtSjxj6ktQjhr4k9Yihv8k4Rqtx8HW0ebWap6/u802qcfB1tPnZ05ekHjH0JalHDH1J6hFDfxNwHFZSW4a+JPWIoS9pQT5nf3My9CWpRwx9SeoRb86S9IaGh3jqjppgJRoHe/qSWnOMf+Ozp7+B+QaUtFyG/gZk2EtaKYd3JK2InY+NqVXoJ9mT5HSSuSS3LbD/kiT/odn/5SQ7mu07kryQ5LHm69PjLV+StBxLDu8k2QIcBq4H5oHjSWaq6tRQs5uB56vqR5LsB34duLHZ92RVXT3muiVNiD38ja1NT383MFdVZ6rqJeAosG+kzT7g883yvcB7kvjKkKSOaRP6W4Fnhtbnm20Ltqmq88C3gB9s9u1M8miSP0nyUwv9gCQHkswmmT137tyyfgFJk+OjGjaetb6Q+w1ge1VdA9wK3J3kraONququqpququmpqak1LkmS+qtN6J8Frhha39ZsW7BNkouAtwHfrKoXq+qbAFV1AngSeOdqi5bUPfb6N4Y2oX8c2JVkZ5KLgf3AzEibGeCmZvmDwINVVUmmmgvBJHkHsAs4M57SJUnLteTsnao6n+QgcD+wBThSVSeTHAJmq2oG+Bzw20nmgOcYnBgArgUOJXkZeBX4aFU9txa/yGZnD0rSOLS6I7eqjgHHRrbdPrT8V8CHFjjuPuC+VdYoqePslGwc3pErac14Mugen70jaawM+m6zpy9p3TnTZ3Ls6UtaU6Ph7gexTJY9fUnqEUN/A/DPYG0mvp4ny+GdDvPNIWnc7OlL6gQ7OevDnr6kiTHo1589/Y7yzSBpLdjTl9QZw52d4amdF7Y73XP17OlL6qTFbuDyxq7VMfQldZoBP14O73SML3Bpcb4/Vs+eviT1iD39DvAilbR8i1301Ruzpz9hwy9c/3SVVs4LvO0Y+pI2PDtP7Tm8MyG+MKW1Mzxk6vDp69nTnwADX1ofo+81h4Ba9vST7AH+NbAF+GxVfWJk/yXAbwF/C/gmcGNVPdXs+zhwM/AK8PNVdf/Yqt9g+v5ikyZpqSGgvvwlsGToJ9kCHAauB+aB40lmqurUULObgeer6keS7Ad+HbgxyVXAfuDHgR8C/jDJO6vqlXH/IpK0Gn35hK82Pf3dwFxVnQFIchTYBwyH/j7gV5rle4F/kyTN9qNV9SLw9SRzzff7b+Mpf+Owly9tLIu9Zzf6yaBN6G8FnhlanwfetVibqjqf5FvADzbbHx45duuKq+0Yg1zqn3G97yd18ujE7J0kB4ADzep3kpxexbe7HPjL1Vc1Vl2sCaxrObpYE3Szri7WBB2rK7/y2sljXHVd2aZRm9A/C1wxtL6t2bZQm/kkFwFvY3BBt82xVNVdwF1tCl5Kktmqmh7H9xqXLtYE1rUcXawJullXF2sC67qgzZTN48CuJDuTXMzgwuzMSJsZ4KZm+YPAg1VVzfb9SS5JshPYBXxlPKVLkpZryZ5+M0Z/ELifwZTNI1V1MskhYLaqZoDPAb/dXKh9jsGJgabdPQwu+p4HbnHmjiRNTqsx/ao6Bhwb2Xb70PJfAR9a
5NhfA35tFTUu11iGicasizWBdS1HF2uCbtbVxZrAugDIYBRGktQHPoZBknpk04R+kj1JTieZS3LbpOsBSHIkybNJ/mzStVyQ5IokDyU5leRkko9NuiaAJG9K8pUkf9rUdeekaxqWZEuSR5P8x0nXApDkqSSPJ3ksyeyk67kgyfcnuTfJnyd5Isnf6UBNP9r8O134+naSX+hAXf+0ea3/WZIvJHnTuvzczTC80zwq4r8z9KgI4MMjj4qYRF3XAt8Bfquq/sYka7kgyduBt1fVI0n+GnAC+EAH/q0CXFpV30nyvcCXgI9V1cNLHLouktwKTANvraqf6UA9TwHTVdWZeecAST4P/Jeq+mwz2+/NVfW/J13XBU1WnAXeVVV/McE6tjJ4jV9VVS80E16OVdW/W+ufvVl6+q89KqKqXgIuPCpioqrqPzOYzdQZVfWNqnqkWf4/wBN04C7pGvhOs/q9zVcneiRJtgHvBz476Vq6LMnbgGsZzOajql7qUuA33gM8OcnAH3IR8H3NvU1vBv7HevzQzRL6Cz0qYuJB1nVJdgDXAF+ebCUDzRDKY8CzwANV1Ym6gH8F/HPg1UkXMqSAP0hyormjvQt2AueAf9sMhX02yaWTLmrEfuALky6iqs4C/xJ4GvgG8K2q+oP1+NmbJfS1TEneAtwH/EJVfXvS9QBU1StVdTWDO7d3J5n4kFiSnwGeraoTk65lxN+rqp8A3gfc0gwlTtpFwE8Av1lV1wD/F+jE9TWAZrhpL/DFDtRyGYPRiJ0MnkB8aZKPrMfP3iyh3+pxDxpoxszvA36nqn530vWMaoYEHgL2TLoW4N3A3mYM/Sjw00n+/WRLeq2nSFU9C/wegyHOSZsH5of+QruXwUmgK94HPFJV/2vShQDXAV+vqnNV9TLwu8DfXY8fvFlCv82jIsRrF0w/BzxRVZ+adD0XJJlK8v3N8vcxuCj/55OtCqrq41W1rap2MHhdPVhV69IjW0ySS5uL8DTDJzcAE58hVlX/E3gmyY82m97D6x/BPmkfpgNDO42ngb+d5M3Ne/I9DK6vrblOPGVztRZ7VMSEyyLJF4B/AFyeZB64o6o+N9mqeDfwD4HHm/FzgH/R3HU9SW8HPt/Mrvge4J6q6sT0yA7668DvDbKCi4C7q+o/Tbak1/wT4HeaztcZ4B9NuB7gtZPj9cA/nnQtAFX15ST3Ao8weETNo6zTnbmbYsqmJKmdzTK8I0lqwdCXpB4x9CWpRwx9SeoRQ1+SesTQl6QeMfQlqUcMfUnqkf8Hl1oiF0uwaawAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# Histogram of the distances in row 0 of K, normalised to a density.\n",
    "fig, ax = plt.subplots()\n",
    "ax.hist(K[0, :], bins=200, density=True, facecolor='g')\n",
    "ax.set_xlabel('Euclidean distance (row 0 of K)')\n",
    "ax.set_ylabel('Density')\n",
    "ax.set_title('Distances from the first batch sample to the training set')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "ename": "IndexError",
     "evalue": "index 29658 is out of bounds for axis 0 with size 5",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m\u001b[0m",
      "\u001b[0;31mIndexError\u001b[0mTraceback (most recent call last)",
      "\u001b[0;32m<ipython-input-10-73e4e151967e>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mK\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mselected\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
      "\u001b[0;31mIndexError\u001b[0m: index 29658 is out of bounds for axis 0 with size 5"
     ]
    }
   ],
   "source": [
    "K[selected]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(5, 10)"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "selected.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
