{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "import pickle\n",
    "import gzip\n",
    "import argparse\n",
    "import os\n",
    "import sys\n",
    "import numpy as np\n",
    "\n",
    "from smiles_rnn_distribution_learner import SmilesRnnDistributionLearner\n",
    "\n",
    "from guacamol.utils.helpers import setup_default_logger\n",
    "setup_default_logger()\n",
    "\n",
    "from atalaya import Logger\n",
    "\n",
    "# Command-line style configuration for the SMILES RNN distribution-learning run.\n",
    "parser = argparse.ArgumentParser(description='Distribution learning benchmark for SMILES RNN',\n",
    "                                 formatter_class=argparse.ArgumentDefaultsHelpFormatter)\n",
    "# The triple-quoted string below is never assigned or used: it only disables the\n",
    "# Guacamol and QM9 path arguments, leaving the ZINC definitions after it as the active ones.\n",
    "'''\n",
    "parser.add_argument('--data_path', default = '../../data/Guacamol', help='training data path')\n",
    "parser.add_argument('--train_data', default='../../data/Guacamol/train_props.npz',\n",
    "                        help='Full path to SMILES file containing training data')\n",
    "parser.add_argument('--valid_data', default='../../data/Guacamol/valid_props_smaller.npz',\n",
    "                        help='Full path to SMILES file containing validation data')\n",
    "\n",
    "parser.add_argument('--output_dir', default='./output/Guacamol/', help='Output directory')\n",
    "\n",
    "\n",
    "## qm9\n",
    "parser.add_argument('--data_path', default = '../../data/QM9', help='training data path')\n",
    "parser.add_argument('--valid_data', default='', help='Full path to SMILES file containing validation data')\n",
    "parser.add_argument('--train_data', default='../../data/QM9/QM9_clean_smi_train_smile.npz',\n",
    "                    help='Full path to SMILES file containing training data')\n",
    "parser.add_argument('--output_dir', default='./output/QM9/debug', help='Output directory')\n",
    "\n",
    "'''\n",
    "\n",
    "# Active dataset configuration: ZINC paths (validation path defaults to an empty string).\n",
    "parser.add_argument('--data_path', default = '../../data/ZINC', help='training data path')\n",
    "parser.add_argument('--valid_data', default='', help='Full path to SMILES file containing validation data')\n",
    "parser.add_argument('--train_data', default='../../data/ZINC/ZINC_clean_smi_train_smile.npz',\n",
    "                    help='Full path to SMILES file containing training data')\n",
    "parser.add_argument('--output_dir', default='./output/ZINC/debug', help='Output directory')\n",
    "\n",
    "\n",
    "# Model and optimisation hyper-parameters.\n",
    "parser.add_argument('--entropy_lambda', default=0.0, type=float, help='weighting of the entropy term in the loss')\n",
    "parser.add_argument('--batch_size', default=60, type=int, help='Size of a mini-batch for gradient descent')\n",
    "parser.add_argument('--valid_every', default=1000, type=int, help='Validate every so many batches')\n",
    "parser.add_argument('--print_every', default=10, type=int, help='Report every so many batches')\n",
    "parser.add_argument('--n_epochs', default=100, type=int, help='Number of training epochs')\n",
    "parser.add_argument('--max_len', default=100, type=int, help='Max length of a SMILES string')\n",
    "parser.add_argument('--hidden_size', default=512, type=int, help='Size of hidden layer')\n",
    "parser.add_argument('--n_layers', default=3, type=int, help='Number of layers for training')\n",
    "parser.add_argument('--rnn_dropout', default=0.2, type=float, help='Dropout value for RNN')\n",
    "parser.add_argument('--lr', default=1e-3, type=float, help='RNN learning rate')\n",
    "parser.add_argument('--seed', default=42, type=int, help='Random seed')\n",
    "\n",
    "# parse_known_args() returns (namespace, unrecognised_argv); the unrecognised part is discarded.\n",
    "# NOTE(review): presumably used so arguments injected by the notebook kernel do not abort parsing - confirm.\n",
    "args,_ = parser.parse_known_args()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fall back to the current working directory when no output directory is given.\n",
    "# A notebook kernel does not define __file__, so the script-style lookup\n",
    "# os.path.realpath(__file__) would raise NameError in this branch.\n",
    "if args.output_dir is None:\n",
    "    args.output_dir = os.getcwd()\n",
    "    print('output directory is:', args.output_dir)\n",
    "\n",
    "# exist_ok=True replaces the separate exists() check and its check-then-create race.\n",
    "os.makedirs(args.output_dir, exist_ok=True)\n",
    "\n",
    "graph_logger = Logger(\n",
    "    # NOTE(review): 'batchszie' looks like a typo for 'batchsize'; kept unchanged because\n",
    "    # editing it would change the logger name.\n",
    "    name=\"ZINC_batchszie_\" + str(args.batch_size) + '_lstm',  # name of the logger\n",
    "    # os.path.join supplies the separator that plain concatenation dropped\n",
    "    # ('./output/ZINC/debug' + 'logs' produced './output/ZINC/debuglogs').\n",
    "    path=os.path.join(args.output_dir, \"logs\"),  # path to logs\n",
    "    verbose=True,       # logger in verbose mode\n",
    "    grapher=\"visdom\",\n",
    "    server=\"http://send2.visdom.xyz\",\n",
    "    port=8999\n",
    ")\n",
    "\n",
    "# Hand the parsed arguments to the logger.\n",
    "graph_logger.add_parameters(args)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the learner from the parsed hyper-parameters; the data set tag is fixed to 'ZINC'.\n",
    "trainer = SmilesRnnDistributionLearner(\n",
    "    data_set='ZINC',\n",
    "    graph_logger=graph_logger,\n",
    "    output_dir=args.output_dir,\n",
    "    n_epochs=args.n_epochs,\n",
    "    hidden_size=args.hidden_size,\n",
    "    n_layers=args.n_layers,\n",
    "    max_len=args.max_len,\n",
    "    batch_size=args.batch_size,\n",
    "    rnn_dropout=args.rnn_dropout,\n",
    "    lr=args.lr,\n",
    "    valid_every=args.valid_every,\n",
    ")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO : CUDA enabled:\tTrue\n",
      "CUDA enabled:\tTrue\n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m\u001b[0m",
      "\u001b[0;31mKeyboardInterrupt\u001b[0mTraceback (most recent call last)",
      "\u001b[0;32m<ipython-input-4-e5b9446cb93a>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mtrainer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrain\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata_path\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtraining_set\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrain_data\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvalidation_set\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0margs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalid_data\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m      2\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34mf'All done, your trained model is in {args.output_dir}'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/work/conditional-molecules/code/gencond/lstm/smiles_rnn_distribution_learner.py\u001b[0m in \u001b[0;36mtrain\u001b[0;34m(self, data_path, training_set, validation_set)\u001b[0m\n\u001b[1;32m     58\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     59\u001b[0m         \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 60\u001b[0;31m             \u001b[0mtrain_seqs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtrain_prop\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mload_smiles_and_properties\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtraining_set\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmax_len\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     61\u001b[0m             \u001b[0;31m#sample_indexs = np.arange(train_seqs.shape[0])\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     62\u001b[0m             \u001b[0;31m#random.shuffle(sample_indexs)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/work/conditional-molecules/code/gencond/lstm/rnn_utils.py\u001b[0m in \u001b[0;36mload_smiles_and_properties\u001b[0;34m(smiles_file, rm_invalid, rm_duplicates, max_len)\u001b[0m\n\u001b[1;32m     91\u001b[0m     \"\"\"\n\u001b[1;32m     92\u001b[0m     \u001b[0msmiles_list\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mproperties\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mload\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msmiles_file\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mallow_pickle\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 93\u001b[0;31m     \u001b[0msmiles\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0m_\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mload_smiles_from_list\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msmiles_list\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrm_invalid\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mrm_invalid\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrm_duplicates\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mrm_duplicates\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmax_len\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mmax_len\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     94\u001b[0m     \u001b[0;32mreturn\u001b[0m \u001b[0msmiles\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mproperties\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     95\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/work/conditional-molecules/code/gencond/lstm/rnn_utils.py\u001b[0m in \u001b[0;36mload_smiles_from_list\u001b[0;34m(smiles_list, rm_invalid, rm_duplicates, max_len)\u001b[0m\n\u001b[1;32m    164\u001b[0m         \u001b[0menc_smi\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0msd\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mBEGIN\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0msd\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mencode\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmol\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0msd\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mEND\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    165\u001b[0m         \u001b[0;32mfor\u001b[0m \u001b[0mc\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0menc_smi\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 166\u001b[0;31m             \u001b[0msequences\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mc\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0msd\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mchar_idx\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0menc_smi\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mc\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    167\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    168\u001b[0m     \u001b[0;32mreturn\u001b[0m \u001b[0msequences\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvalid_mask\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "\n",
    "# Run training with the configured data locations, then report where the model was written.\n",
    "trainer.train(\n",
    "    args.data_path,\n",
    "    training_set=args.train_data,\n",
    "    validation_set=args.valid_data,\n",
    ")\n",
    "print(f'All done, your trained model is in {args.output_dir}')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "from rnn_utils import get_tensor_dataset, load_smiles_and_properties, set_random_seed,load_smiles_from_list"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pass max_len by keyword: the second positional parameter of\n",
    "# load_smiles_and_properties is rm_invalid, so a positional args.max_len was\n",
    "# being interpreted as a truthy rm_invalid while max_len stayed at its default.\n",
    "train_seqs, train_prop = load_smiles_and_properties(args.train_data, max_len=args.max_len)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(244444, 102)"
      ]
     },
     "execution_count": 30,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_seqs.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(244456, 9)"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_prop.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# SECURITY: allow_pickle=True can execute arbitrary code while loading; only open trusted files.\n",
    "# The context manager closes the archive once both arrays have been read from it.\n",
    "with np.load(args.train_data, allow_pickle=True) as train_archive:\n",
    "    smiles_list, properties = train_archive.values()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(244456,)"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "smiles_list.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(244456, 9)"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "properties.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Use the configured max_len instead of a hard-coded 100, and bind the second result to a\n",
    "# real name rather than `_`, which IPython also uses for the most recent cell output.\n",
    "smiles, valid_mask = load_smiles_from_list(smiles_list, rm_invalid=False, rm_duplicates=False, max_len=args.max_len)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(244456, 102)"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "smiles.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_smiles_and_properties(smiles_file, rm_invalid=False, rm_duplicates=False, max_len=100):\n",
    "    \"\"\"\n",
    "    Load a saved list of smiles and associated precomputed properties.\n",
    "\n",
    "    Args:\n",
    "        smiles_file: path of an archive readable by np.load; it must hold exactly two\n",
    "            arrays, unpacked in stored order as (smiles_list, properties).\n",
    "        rm_invalid: forwarded unchanged to load_smiles_from_list.\n",
    "        rm_duplicates: forwarded unchanged to load_smiles_from_list.\n",
    "        max_len: forwarded unchanged to load_smiles_from_list.\n",
    "\n",
    "    Returns:\n",
    "        (smiles, properties): the first value returned by load_smiles_from_list and the\n",
    "        properties array exactly as stored in the archive.\n",
    "\n",
    "    NOTE(review): properties is returned unfiltered, so if load_smiles_from_list drops\n",
    "    rows (rm_invalid / rm_duplicates) the two results can differ in length - confirm\n",
    "    whether callers need them aligned. Prefer passing the flags and max_len by keyword;\n",
    "    a positional max_len lands in the rm_invalid slot.\n",
    "    \"\"\"\n",
    "    # SECURITY: allow_pickle=True can execute arbitrary code while loading; only open trusted files.\n",
    "    # The context manager closes the archive once both arrays have been read from it.\n",
    "    with np.load(smiles_file, allow_pickle=True) as archive:\n",
    "        smiles_list, properties = archive.values()\n",
    "    smiles, _ = load_smiles_from_list(smiles_list, rm_invalid=rm_invalid, rm_duplicates=rm_duplicates, max_len=max_len)\n",
    "    return smiles, properties"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
