{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "import pickle\n",
    "import gzip\n",
    "import argparse\n",
    "import os\n",
    "import sys\n",
    "import numpy as np\n",
    "\n",
    "from smiles_rnn_distribution_learner_augmented_data import SmilesRnnDistributionLearner\n",
    "\n",
    "from guacamol.utils.helpers import setup_default_logger\n",
    "setup_default_logger()\n",
    "\n",
    "from atalaya import Logger\n",
    "\n",
    "parser = argparse.ArgumentParser(description='Distribution learning benchmark for SMILES RNN',\n",
    "                                 formatter_class=argparse.ArgumentDefaultsHelpFormatter)\n",
    "'''\n",
    "parser.add_argument('--data_path', default = '../../data/Guacamol', help='training data path')\n",
    "parser.add_argument('--train_data', default='../../data/Guacamol/train_props.npz',\n",
    "                        help='Full path to SMILES file containing training data')\n",
    "parser.add_argument('--valid_data', default='../../data/Guacamol/valid_props_smaller.npz',\n",
    "                        help='Full path to SMILES file containing validation data')\n",
    "\n",
    "parser.add_argument('--output_dir', default='./output/Guacamol/', help='Output directory')\n",
    "\n",
    "\n",
    "'''\n",
    "parser.add_argument('--output_dir', default='./output/QM9/classic_data_augmentation', help='Output directory')\n",
    "parser.add_argument('--augmented_data', default='../../data/QM9/smiles_data_augmentation.npz')\n",
    "parser.add_argument('--data_path', default = '../../data/QM9', help='training data path')\n",
    "parser.add_argument('--valid_data', default='', help='Full path to SMILES file containing validation data')\n",
    "parser.add_argument('--train_data', default='../../data/QM9/QM9_clean_smi_train_smile.npz',\n",
    "                    help='Full path to SMILES file containing training data')\n",
    "\n",
    "\n",
    "parser.add_argument('--entropy_lambda', default=0.0, type=float, help='weighting of the entropy term in the loss')\n",
    "parser.add_argument('--batch_size', default=20, type=int, help='Size of a mini-batch for gradient descent')\n",
    "parser.add_argument('--valid_every', default=1000, type=int, help='Validate every so many batches')\n",
    "parser.add_argument('--print_every', default=10, type=int, help='Report every so many batches')\n",
    "parser.add_argument('--n_epochs', default=100, type=int, help='Number of training epochs')\n",
    "parser.add_argument('--max_len', default=100, type=int, help='Max length of a SMILES string')\n",
    "parser.add_argument('--hidden_size', default=512, type=int, help='Size of hidden layer')\n",
    "parser.add_argument('--n_layers', default=3, type=int, help='Number of layers for training')\n",
    "parser.add_argument('--rnn_dropout', default=0.2, type=float, help='Dropout value for RNN')\n",
    "parser.add_argument('--lr', default=1e-3, type=float, help='RNN learning rate')\n",
    "parser.add_argument('--seed', default=42, type=int, help='Random seed')\n",
    "\n",
    "args,_ = parser.parse_known_args()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Resolve the output directory. `__file__` is not defined inside a notebook kernel,\n",
    "# so fall back to the current working directory when it is unavailable.\n",
    "if args.output_dir is None:\n",
    "    script_path = globals().get('__file__')\n",
    "    args.output_dir = os.path.dirname(os.path.realpath(script_path)) if script_path else os.getcwd()\n",
    "    print('output directory is:', args.output_dir)\n",
    "\n",
    "# exist_ok=True makes this safe to re-run and avoids a separate existence check.\n",
    "os.makedirs(args.output_dir, exist_ok=True)\n",
    "\n",
    "graph_logger = Logger(\n",
    "    name=\"QM9_batchszie_\" + str(args.batch_size) + '_classic_data_augmentation',  # name of the logger\n",
    "    # Join with a path separator: plain string concatenation produced '<output_dir>logs'\n",
    "    # whenever output_dir had no trailing slash.\n",
    "    path=os.path.join(args.output_dir, 'logs'),  # path to logs\n",
    "    verbose=True,       # logger in verbose mode\n",
    "    grapher=\"visdom\",\n",
    "    server=\"http://send2.visdom.xyz\",\n",
    "    port=8999\n",
    ")\n",
    "\n",
    "# Hand the parsed arguments to the logger.\n",
    "graph_logger.add_parameters(args)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Gather the learner settings from the parsed arguments, then build the\n",
    "# distribution learner from them in a single call.\n",
    "learner_options = dict(\n",
    "    data_set='QM9',\n",
    "    graph_logger=graph_logger,\n",
    "    output_dir=args.output_dir,\n",
    "    n_epochs=args.n_epochs,\n",
    "    hidden_size=args.hidden_size,\n",
    "    n_layers=args.n_layers,\n",
    "    max_len=args.max_len,\n",
    "    batch_size=args.batch_size,\n",
    "    rnn_dropout=args.rnn_dropout,\n",
    "    lr=args.lr,\n",
    "    valid_every=args.valid_every,\n",
    ")\n",
    "trainer = SmilesRnnDistributionLearner(**learner_options)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Start training on the configured training set together with the augmented data.\n",
    "trainer.train(\n",
    "    args.data_path,\n",
    "    training_set=args.train_data,\n",
    "    augmented_data=args.augmented_data,\n",
    ")\n",
    "print(f'All done, your trained model is in {args.output_dir}')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
