{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "import gzip\n",
    "import argparse\n",
    "import os\n",
    "import sys\n",
    "from smiles_rnn_distribution_learner import SmilesRnnDistributionLearner\n",
    "\n",
    "from atalaya import Logger\n",
    "graph_logger = Logger(\n",
    "    name=\"QM9_batchszie_200_lstm\",         # name of the logger\n",
    "    path=\"logs\",        # path to logs\n",
    "    verbose=True,       # logger in verbose mode\n",
    "    grapher=\"visdom\",\n",
    "    server=\"http://send2.visdom.xyz\",\n",
    "    port=8999\n",
    ")\n",
    "\n",
    "sys.path.append('../../')\n",
    "from guacamol.utils.helpers import setup_default_logger\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<atalaya.logger.Logger at 0x7f373d43c5f8>"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "graph_logger"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "parser = argparse.ArgumentParser(description='Distribution learning benchmark for SMILES RNN',\n",
    "                                 formatter_class=argparse.ArgumentDefaultsHelpFormatter)\n",
    "parser.add_argument('--data_path', default='../../data/QM9')\n",
    "parser.add_argument('--train_data', default='../../data/QM9/QM9_clean_smi_train_smile.npz',\n",
    "                    help='Full path to SMILES file containing training data')\n",
    "parser.add_argument('--valid_data', default='',\n",
    "                    help='Full path to SMILES file containing validation data')\n",
    "parser.add_argument('--batch_size', default=200, type=int, help='Size of a mini-batch for gradient descent')\n",
    "parser.add_argument('--valid_every', default=1000, type=int, help='Validate every so many batches')\n",
    "parser.add_argument('--print_every', default=10, type=int, help='Report every so many batches')\n",
    "parser.add_argument('--n_epochs', default=100, type=int, help='Number of training epochs')\n",
    "parser.add_argument('--max_len', default=100, type=int, help='Max length of a SMILES string')\n",
    "parser.add_argument('--hidden_size', default=512, type=int, help='Size of hidden layer')\n",
    "parser.add_argument('--n_layers', default=3, type=int, help='Number of layers for training')\n",
    "parser.add_argument('--rnn_dropout', default=0.2, type=float, help='Dropout value for RNN')\n",
    "parser.add_argument('--lr', default=1e-3, type=float, help='RNN learning rate')\n",
    "parser.add_argument('--seed', default=42, type=int, help='Random seed')\n",
    "#parser.add_argument('--prop_model', default=\"../../data/QM9/prior.pkl.gz\", help='Saved model for properties distribution')    \n",
    "parser.add_argument('--output_dir', default='./output/QM9/', help='Output directory')\n",
    "args,_ = parser.parse_known_args()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "if args.output_dir is None:\n",
    "    args.output_dir = os.path.dirname(os.path.realpath(__file__))\n",
    "\n",
    "if not os.path.exists(args.output_dir):\n",
    "    os.makedirs(args.output_dir)\n",
    "\n",
    "graph_logger.add_parameters(args)\n",
    "\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "CUDA enabled:\tTrue\n",
      "EPOCH 1\n",
      "VALID | elapsed: 0:00:05 | epoch: 1/100 (0.0%) | molecules: 200 | valid_loss: 3.6126\n",
      "\n",
      "model_1_3.613\n",
      "TRAIN | elapsed: 0:00:08 | epoch|batch : 1|10 (0.0%) | molecules: 2200 | mols/sec: 272.09 | train_loss: 2.8474\n",
      "TRAIN | elapsed: 0:00:11 | epoch|batch : 1|20 (0.0%) | molecules: 4200 | mols/sec: 380.64 | train_loss: 2.1033\n",
      "TRAIN | elapsed: 0:00:14 | epoch|batch : 1|30 (0.1%) | molecules: 6200 | mols/sec: 437.04 | train_loss: 1.8999\n",
      "TRAIN | elapsed: 0:00:17 | epoch|batch : 1|40 (0.1%) | molecules: 8200 | mols/sec: 474.78 | train_loss: 1.8150\n",
      "TRAIN | elapsed: 0:00:20 | epoch|batch : 1|50 (0.1%) | molecules: 10200 | mols/sec: 492.70 | train_loss: 1.7339\n",
      "TRAIN | elapsed: 0:00:24 | epoch|batch : 1|60 (0.1%) | molecules: 12200 | mols/sec: 508.06 | train_loss: 1.7188\n",
      "TRAIN | elapsed: 0:00:27 | epoch|batch : 1|70 (0.1%) | molecules: 14200 | mols/sec: 524.56 | train_loss: 1.7153\n",
      "TRAIN | elapsed: 0:00:30 | epoch|batch : 1|80 (0.1%) | molecules: 16200 | mols/sec: 538.60 | train_loss: 1.6771\n",
      "TRAIN | elapsed: 0:00:33 | epoch|batch : 1|90 (0.2%) | molecules: 18200 | mols/sec: 549.86 | train_loss: 1.6147\n",
      "TRAIN | elapsed: 0:00:36 | epoch|batch : 1|100 (0.2%) | molecules: 20200 | mols/sec: 559.10 | train_loss: 1.5636\n",
      "TRAIN | elapsed: 0:00:39 | epoch|batch : 1|110 (0.2%) | molecules: 22200 | mols/sec: 566.95 | train_loss: 1.5143\n",
      "TRAIN | elapsed: 0:00:42 | epoch|batch : 1|120 (0.2%) | molecules: 24200 | mols/sec: 574.23 | train_loss: 1.4529\n",
      "TRAIN | elapsed: 0:00:45 | epoch|batch : 1|130 (0.2%) | molecules: 26200 | mols/sec: 580.49 | train_loss: 1.3647\n",
      "TRAIN | elapsed: 0:00:48 | epoch|batch : 1|140 (0.2%) | molecules: 28200 | mols/sec: 586.05 | train_loss: 1.2831\n",
      "TRAIN | elapsed: 0:00:51 | epoch|batch : 1|150 (0.3%) | molecules: 30200 | mols/sec: 588.34 | train_loss: 1.2027\n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m\u001b[0m",
      "\u001b[0;31mKeyboardInterrupt\u001b[0mTraceback (most recent call last)",
      "\u001b[0;32m<ipython-input-5-02f9271f060d>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m     21\u001b[0m \u001b[0;31m#         valid_list = f.readlines()\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     22\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 23\u001b[0;31m \u001b[0mtrainer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrain\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata_path\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtraining_set\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrain_data\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvalidation_set\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0margs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalid_data\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     24\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34mf'All done, your trained model is in {args.output_dir}'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     25\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/work/Documents/conditional_discrete_sequence_generation/conditional-molecules/code/gencond/lstm/smiles_rnn_distribution_learner.py\u001b[0m in \u001b[0;36mtrain\u001b[0;34m(self, data_path, training_set, validation_set)\u001b[0m\n\u001b[1;32m    108\u001b[0m                     \u001b[0mprint_every\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprint_every\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    109\u001b[0m                     \u001b[0mvalid_every\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalid_every\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 110\u001b[0;31m                     n_epochs=self.n_epochs)\n\u001b[0m",
      "\u001b[0;32m/work/Documents/conditional_discrete_sequence_generation/conditional-molecules/code/gencond/lstm/rnn_trainer.py\u001b[0m in \u001b[0;36mfit\u001b[0;34m(self, training_data, test_data, n_epochs, batch_size, print_every, valid_every, num_workers)\u001b[0m\n\u001b[1;32m     91\u001b[0m         training_round = _ModelTrainingRound(self, training_data, test_data, n_epochs, batch_size, print_every,\n\u001b[1;32m     92\u001b[0m                                              valid_every, num_workers)\n\u001b[0;32m---> 93\u001b[0;31m         \u001b[0;32mreturn\u001b[0m \u001b[0mtraining_round\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrun\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     94\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     95\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/work/Documents/conditional_discrete_sequence_generation/conditional-molecules/code/gencond/lstm/rnn_trainer.py\u001b[0m in \u001b[0;36mrun\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m    132\u001b[0m         \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    133\u001b[0m             \u001b[0;32mfor\u001b[0m \u001b[0mepoch_index\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mn_epochs\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 134\u001b[0;31m                 \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_train_one_epoch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mepoch_index\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    135\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    136\u001b[0m             \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_validation_on_final_model\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/work/Documents/conditional_discrete_sequence_generation/conditional-molecules/code/gencond/lstm/rnn_trainer.py\u001b[0m in \u001b[0;36m_train_one_epoch\u001b[0;34m(self, epoch_index)\u001b[0m\n\u001b[1;32m    157\u001b[0m             \u001b[0mbatch\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mbatch_all\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    158\u001b[0m             \u001b[0mproperties\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mbatch_all\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 159\u001b[0;31m             \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_train_one_batch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbatch_index\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbatch\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mproperties\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mepoch_index\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mepoch_t0\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    160\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    161\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0m_train_one_batch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbatch_index\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbatch\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mproperties\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mepoch_index\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtrain_t0\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/work/Documents/conditional_discrete_sequence_generation/conditional-molecules/code/gencond/lstm/rnn_trainer.py\u001b[0m in \u001b[0;36m_train_one_batch\u001b[0;34m(self, batch_index, batch, properties, epoch_index, train_t0)\u001b[0m\n\u001b[1;32m    160\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    161\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0m_train_one_batch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbatch_index\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbatch\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mproperties\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mepoch_index\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtrain_t0\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 162\u001b[0;31m         \u001b[0mloss\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msize\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodel_trainer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrain_on_batch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbatch\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mproperties\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    163\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    164\u001b[0m         \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0munprocessed_train_losses\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mloss\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/work/Documents/conditional_discrete_sequence_generation/conditional-molecules/code/gencond/lstm/rnn_trainer.py\u001b[0m in \u001b[0;36mtrain_on_batch\u001b[0;34m(self, batch, properties)\u001b[0m\n\u001b[1;32m     51\u001b[0m         \u001b[0;31m# forward / backward\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     52\u001b[0m         \u001b[0mloss\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msize\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprocess_batch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbatch\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mproperties\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 53\u001b[0;31m         \u001b[0mloss\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     54\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     55\u001b[0m         \u001b[0;31m# optimize\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/opt/conda/envs/pytorch-py3.6/lib/python3.6/site-packages/torch/tensor.py\u001b[0m in \u001b[0;36mbackward\u001b[0;34m(self, gradient, retain_graph, create_graph)\u001b[0m\n\u001b[1;32m     91\u001b[0m                 \u001b[0mproducts\u001b[0m\u001b[0;34m.\u001b[0m \u001b[0mDefaults\u001b[0m \u001b[0mto\u001b[0m\u001b[0;31m \u001b[0m\u001b[0;31m`\u001b[0m\u001b[0;31m`\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;31m`\u001b[0m\u001b[0;31m`\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     92\u001b[0m         \"\"\"\n\u001b[0;32m---> 93\u001b[0;31m         \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mautograd\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgradient\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     94\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     95\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0mregister_hook\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhook\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/opt/conda/envs/pytorch-py3.6/lib/python3.6/site-packages/torch/autograd/__init__.py\u001b[0m in \u001b[0;36mbackward\u001b[0;34m(tensors, grad_tensors, retain_graph, create_graph, grad_variables)\u001b[0m\n\u001b[1;32m     87\u001b[0m     Variable._execution_engine.run_backward(\n\u001b[1;32m     88\u001b[0m         \u001b[0mtensors\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgrad_tensors\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 89\u001b[0;31m         allow_unreachable=True)  # allow_unreachable flag\n\u001b[0m\u001b[1;32m     90\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     91\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "trainer = SmilesRnnDistributionLearner(data_set = \"QM9\",\n",
    "                                       graph_logger= graph_logger,\n",
    "                                       output_dir=args.output_dir,\n",
    "                                       n_epochs=args.n_epochs,\n",
    "                                       hidden_size=args.hidden_size,\n",
    "                                       n_layers=args.n_layers,\n",
    "                                       max_len=args.max_len,\n",
    "                                       batch_size=args.batch_size,\n",
    "                                       rnn_dropout=args.rnn_dropout,\n",
    "                                       lr=args.lr,\n",
    "                                       #prop_model=prop_model,\n",
    "                                       valid_every=args.valid_every)\n",
    "\n",
    "#     training_set_file = args.train_data\n",
    "#     validation_set_file = args.valid_data\n",
    "# \n",
    "#     with open(training_set_file) as f:\n",
    "#         train_list = f.readlines()\n",
    "# \n",
    "#     with open(validation_set_file) as f:\n",
    "#         valid_list = f.readlines()\n",
    "\n",
    "trainer.train(args.data_path, training_set=args.train_data, validation_set = args.valid_data)\n",
    "print(f'All done, your trained model is in {args.output_dir}')\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
