{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# How well does a conditional LSTM generate molecules which have desired characteristics?"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.insert(0, '..')\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "from rdkit import Chem\n",
    "from pathlib import Path\n",
    "\n",
    "\n",
    "import pickle, gzip, torch\n",
    "import numpy as np\n",
    "\n",
    "import warnings\n",
    "\n",
    "from rnn_utils import load_rnn_model\n",
    "import action_sampler, rnn_sampler\n",
    "from gencond.properties import PROPERTIES, TYPES\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "training smiles num: 113885\n",
      "validation smiles num: (10000,)\n",
      "test smiles num: (10000,)\n"
     ]
    }
   ],
   "source": [
    "\n",
    "'''\n",
    "model_path = './output/ZINC/batch_size_60_2/model_90_0.213.pt' \n",
    "#model_path = '../our_model/output/ZINC/batch_size_60_2/model_72_0.219.pt'\n",
    "# zinc\n",
    "test_smiles, test_prop = np.load(\"../../data/ZINC/ZINC_clean_smi_test_smile.npz\", allow_pickle=True).values()\n",
    "train_smiles, train_prop = np.load(\"../../data/ZINC/ZINC_clean_smi_train_smile.npz\", allow_pickle=True).values()\n",
    "train_smiles = train_smiles.tolist()\n",
    "\n",
    "\n",
    "def normalize(prop):\n",
    "    train_smiles, train_prop = np.load(\"../../data/ZINC/ZINC_clean_smi_train_smile.npz\",allow_pickle=True).values()\n",
    "    mean = np.mean(train_prop, axis = 0)\n",
    "    std = np.std(train_prop, axis = 0)\n",
    "    return (prop - mean) / std\n",
    "\n",
    "#validation_smiles = train_smiles[:10000]\n",
    "#def normalize_augmented(prop):\n",
    "#    _, train_prop = np.load(\"../../data/QM9/smiles_data_augmentation.npz\", allow_pickle=True).values()\n",
    "#    mean = np.mean(train_prop, axis = 0)\n",
    "#    std = np.std(train_prop,axis =0)\n",
    "#    return (prop - mean)/ std\n",
    "\n",
    "'''\n",
    "\n",
    "\n",
    "#QM9\n",
    "model_path = './output/QM9/2020_1_17/model_79_0.167.pt' \n",
    "test_smiles, test_prop = np.load(\"../../data/QM9/QM9_clean_smi_test_smile.npz\", allow_pickle=True).values()\n",
    "train_smiles, train_prop = np.load(\"../../data/QM9/QM9_clean_smi_train_smile.npz\", allow_pickle=True).values()\n",
    "val_smiles, val_prop = train_smiles[:10000],train_prop[:10000,:]\n",
    "train_smiles, train_prop = train_smiles[10000:],train_prop[10000,:]\n",
    "train_smiles = train_smiles.tolist()\n",
    "print('training smiles num:', len(train_smiles))\n",
    "print('validation smiles num:', val_smiles.shape)\n",
    "print('test smiles num:', test_smiles.shape)\n",
    "\n",
    "\n",
    "def normalize(prop):\n",
    "    train_smiles, train_prop = np.load(\"../../data/QM9/QM9_clean_smi_train_smile.npz\",allow_pickle=True).values()\n",
    "    mean = np.mean(train_prop, axis = 0)\n",
    "    std = np.std(train_prop, axis = 0)\n",
    "    return (prop - mean) / std\n",
    "\n",
    "'''\n",
    "## load gucamol train smiles\n",
    "#def normalize(prop):\n",
    "    #mean, std = np.load('../../data/Guacamol/normalizer.npy',allow_pickle=True)\n",
    "    #return (prop - mean) / std\n",
    "\n",
    "def normalize(prop):    \n",
    "    s, y = np.load(\"../../data/Guacamol/train_props.npz\").values()\n",
    "    mean = np.mean(y, axis = 0)\n",
    "    std = np.std(y, axis = 0)\n",
    "    return (prop - mean)/std\n",
    "\n",
    "train_smiles = []\n",
    "for row in open(\"../../data/Guacamol/guacamol_v1_train.smiles\"):\n",
    "    row = row.strip()\n",
    "    train_smiles.append(row)   \n",
    "    \n",
    "#guacamol\n",
    "test_smiles, test_prop = np.load(\"../../data/Guacamol/test_props.npz\").values()\n",
    "train_smile, train_prop = np.load(\"../../data/Guacamol/train_props.npz\").values()\n",
    "\n",
    "model_path = '../our_model_icml/output/Guacamol/continued/batch_size_100/model_12_0.140.pt' \n",
    "#model_path = './output/Guacamol/baobab/model_91_0.132.pt'\n",
    "print('training smiles num:', len(train_smiles))\n",
    "#print('validation smiles num:', val_smiles.shape)\n",
    "print('test smiles num:', test_smiles.shape)   \n",
    "'''\n",
    "model_def = Path(model_path).with_suffix('.json')\n",
    "model = load_rnn_model(model_def, model_path, device='cpu', copy_to_cpu=True)\n",
    "sampler = rnn_sampler.ConditionalSmilesRnnSampler(device='cpu')\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Test on random subset of 1000 molecules in test set\n",
    "\n",
    "For each of the 1000 molecules, compute their properties, and then sample from the LSTM 10 SMILES strings conditioned on these property values."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "113885"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(train_smiles)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "#np.random.seed(42)\n",
    "subset = np.random.permutation(len(test_smiles)) #[:1000]\n",
    "#subset = subset[:10000]\n",
    "#subset = np.random.permutation(1000)\n",
    "sample_size = 10\n",
    "test_smiles = test_smiles[subset]\n",
    "test_properties = test_prop[subset]\n",
    "generated = sampler.sample(model, normalize(test_properties), sample_size)\n",
    "#generated = sampler.sample(model, prior.transform(test_properties), sample_size)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_smiles.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "##for figure2 \n",
    "### to generate multiple samples for one target proop\n",
    "np.random.seed(1)\n",
    "#subset = np.random.permutation(len(test_smiles)) #[:1000]\n",
    "#subset = np.random.permutation(10000)\n",
    "k = 11\n",
    "sample_size = 10\n",
    "target_smiles = test_smiles[k:k+1]\n",
    "target_properties = test_prop[k:k+1]\n",
    "generated = sampler.sample(model, normalize(target_properties), sample_size)\n",
    "#generated = sampler.sample(model, prior.transform(test_properties), sample_size)\n",
    "\n",
    "prop = []\n",
    "canonical = []\n",
    "for i in range(len(generated[0])):\n",
    "    sample = generated[0][i]\n",
    "    mol = Chem.MolFromSmiles(sample)\n",
    "    try:\n",
    "        prop.append([p(mol) for p in PROPERTIES.values()])\n",
    "        canonical.append(Chem.MolToSmiles(mol))\n",
    "    except:\n",
    "            pass # m    \n",
    "\n",
    "for i in range(len(canonical)):\n",
    "    print(i)\n",
    "    plt.scatter(np.arange(9), np.array(prop[i]).round(3).flatten()) #alpha=1, s =80)\n",
    "    \n",
    "plt.scatter(np.arange(9), target_properties.round(3).flatten(),color='red',alpha=1)\n",
    "plt.xlabel('property dimension index')\n",
    "plt.ylabel('property value')  \n",
    "\n",
    "from rdkit import Chem\n",
    "from rdkit.Chem import Draw\n",
    "\n",
    "sampled_mols = [Chem.MolFromSmiles(canonical[i]) for i in set(np.arange(len(canonical)))]\n",
    "Draw.MolsToGridImage(sampled_mols, molsPerRow=5, subImgSize=(300, 300))\n",
    "\n",
    "\n",
    "m = Chem.MolFromSmiles(target_smiles[0])\n",
    "Draw.MolToImage(m)\n",
    "len(np.unique(np.array(canonical)))\n",
    "train_mol = list(set(train_smiles).intersection(set(canonical))) # remove the validation set\n",
    "sampled_mols = [Chem.MolFromSmiles(train_mol[i]) for i in set(np.arange(len(train_mol)))]\n",
    "Draw.MolsToGridImage(sampled_mols, molsPerRow=5, subImgSize=(300, 300))n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Check validity, novelty and unicity"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "validity\n",
      "unicity\n",
      "noelty\n",
      "0.9642\n",
      "0.558276\n",
      "0.645831\n"
     ]
    }
   ],
   "source": [
    "from itertools import chain\n",
    "valid_mol = []\n",
    "num_valid = 0\n",
    "for row in range(generated.shape[0]):\n",
    "    valid = []\n",
    "    for samples in generated[row]:\n",
    "        mol = Chem.MolFromSmiles(samples)\n",
    "        try:\n",
    "            valid.append(Chem.MolToSmiles(mol))\n",
    "            num_valid = num_valid + 1\n",
    "            \n",
    "        except:\n",
    "            pass # molecule invalid\n",
    "    valid_mol.append(valid)    \n",
    "validity_ = num_valid / (len(generated) * generated[0].shape[0])\n",
    "print('validity')\n",
    "print('unicity')\n",
    "print('noelty')                                  \n",
    "\n",
    "print(np.around(validity_, decimals=6))\n",
    "str_list = [x for x in valid_mol if x != []]\n",
    "valid_mol_list = list(chain.from_iterable(str_list))\n",
    "num_unique = len(np.unique(np.array(valid_mol_list), return_index=True)[1])\n",
    "print(np.around(num_unique / num_valid, decimals=6))\n",
    "\n",
    "#common = list(set(train_smiles.intersection(valid_mol_list)))\n",
    "common = list(set(train_smiles).intersection((valid_mol_list)))  # remove the validation set\n",
    "novelty_= (num_valid - len(common)) / num_valid \n",
    "print(np.around(novelty_, decimals = 6))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Compute the properties of these generated smiles strings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "simulated = []\n",
    "smiles_matches = []\n",
    "index = []\n",
    "for row in range(generated.shape[0]):\n",
    "    prop = []\n",
    "    canonical = []\n",
    "    for sample in generated[row]:\n",
    "        mol = Chem.MolFromSmiles(sample)\n",
    "        try:\n",
    "            prop.append([p(mol) for p in PROPERTIES.values()])\n",
    "            canonical.append(Chem.MolToSmiles(mol))\n",
    "        except:\n",
    "            pass # molecule invalid\n",
    "        \n",
    "    if prop !=[]:\n",
    "        index.append(row)\n",
    "        \n",
    "    simulated.append(np.array(prop))\n",
    "    smiles_matches.append(np.sum([test_smiles == s for s in canonical]))\n",
    "    "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## First, check molecule validity, and confirm that we aren't just always recreating the same SMILES string we started with…"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[0.996233 0.998001 0.994686 0.971792 0.996827 0.995805 0.993406 0.972188\n",
      " 0.995973]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/opt/conda/lib/python3.6/site-packages/ipykernel_launcher.py:2: RuntimeWarning: Mean of empty slice.\n",
      "  \n"
     ]
    }
   ],
   "source": [
    "if simulated[0].shape[0]>1:\n",
    "    mean_prop= [item.mean(axis = 0 ) for item in simulated]\n",
    "\n",
    "else:\n",
    "    mean_prop = simulated\n",
    "    \n",
    "valid_smiles_prop = np.array(mean_prop)[index]\n",
    "valid_smiles_prop = np.vstack(valid_smiles_prop)\n",
    "test_prop = test_properties[index]\n",
    "corr_coef = []\n",
    "for i in range(9):\n",
    "    corr_coef.append(np.corrcoef(valid_smiles_prop[:,i], test_prop[:,i])[0,1])\n",
    "    \n",
    "print(np.around(corr_coef, decimals=6))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Mean absolute error: [3.31287891e-02 1.87653686e-03 6.16217054e-02 9.80934375e-03\n",
      " 8.81407379e-01 2.76334188e+00 3.18890006e-01 1.89471328e-03\n",
      " 2.82933055e-02]\n"
     ]
    }
   ],
   "source": [
    "MAE = np.mean(np.absolute(valid_smiles_prop - test_prop), axis = 0) \n",
    "print('Mean absolute error:', MAE)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([9.18893099e-03, 6.53551246e-04, 1.09930722e-02, 3.22673092e-04,\n",
       "       2.82087648e+00, 1.95609971e+01, 7.88064811e-01, 2.47439008e-03,\n",
       "       1.18536828e-02])"
      ]
     },
     "execution_count": 28,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "MSE_per_prop = np.mean(np.power(valid_smiles_prop - test_prop, 2),axis = 0) \n",
    "MSE_per_prop"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "MSE\n",
      "2.578381\n"
     ]
    }
   ],
   "source": [
    "MSE = np.mean(np.power(valid_smiles_prop - test_prop, 2)) \n",
    "print('MSE')\n",
    "print(np.around(MSE, decimals = 6))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "MAE\n",
      "0.455585\n"
     ]
    }
   ],
   "source": [
    "MAE = np.mean(np.absolute(valid_smiles_prop - test_prop)) \n",
    "print('MAE')\n",
    "print(np.around(MAE, decimals = 6))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "10000"
      ]
     },
     "execution_count": 31,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(simulated)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Plot correlation between input to conditional LSTM model, and properties of generated molecules"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "warnings.filterwarnings('ignore')\n",
    "titles = list(PROPERTIES.keys())\n",
    "\n",
    "plt.figure(figsize=(10,10))\n",
    "for row in range(1000):\n",
    "    prop = simulated[row]\n",
    "    for i in range(prop.shape[-1]):\n",
    "        plt.subplot(3,3,i+1)\n",
    "        jitter = np.random.randn(prop.shape[0], 2) * (0.0 if TYPES[titles[i]] == float else 0.05)\n",
    "        plt.plot(test_properties[row,i]+jitter[:,0], \n",
    "                 prop[:,i] + jitter[:,1], \n",
    "                 marker='.', linewidth=0, color='#3333cc', alpha=0.5)\n",
    "\n",
    "for i in range(9):\n",
    "    plt.subplot(3,3,i+1)\n",
    "    a,b = plt.xlim()\n",
    "    c,d = plt.ylim()\n",
    "    lims = min(a,c), max(b,d)\n",
    "    plt.xlim(lims)\n",
    "    plt.ylim(lims)\n",
    "    plt.title(titles[i] + '\\n Correlation coef:\\n ' +  str(corr_coef[i]))\n",
    "    if i >= 6:\n",
    "        plt.xlabel(\"input value\")\n",
    "    if i % 3 == 0:\n",
    "        plt.ylabel(\"output value\")\n",
    "    \n",
    "plt.tight_layout();"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_properties.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Test set log likelihood"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch.utils.data import DataLoader\n",
    "from rnn_utils import load_smiles_from_list, get_tensor_dataset\n",
    "from smiles_char_dict import SmilesCharDictionary\n",
    "from atalaya import Logger\n",
    "graph_logger = ''\n",
    "\n",
    "#_, train_prop = np.load(\"../../data/QM9/QM9_clean_smi_train_smile.npz\", allow_pickle=True).values()\n",
    "mean = np.mean(train_prop, axis = 0)\n",
    "std = np.std(train_prop,axis =0)\n",
    "\n",
    "\n",
    "sd = SmilesCharDictionary()\n",
    "criterion = torch.nn.CrossEntropyLoss(ignore_index=sd.pad_idx)\n",
    "\n",
    "device = 'cpu'\n",
    "test_s,_ = load_smiles_from_list(test_smiles, test_properties)\n",
    "test_set = get_tensor_dataset(test_s, normalize(test_properties))\n",
    "data_loader = DataLoader(test_set,batch_size=500,shuffle=False,num_workers=1,pin_memory=True)\n",
    "from rnn_trainer import SmilesRnnTrainer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "optimizer = torch.optim.Adam(model.parameters(), 0.01)\n",
    "criterion = torch.nn.CrossEntropyLoss(ignore_index=sd.pad_idx)\n",
    "trainer = SmilesRnnTrainer(graph_logger, model=model,\n",
    "                                   normalizer_mean = mean,\n",
    "                                   normalizer_std = std,\n",
    "                                   criteria=[criterion],\n",
    "                                   optimizer=optimizer,\n",
    "                                   device=device,\n",
    "                                   log_dir=None)\n",
    "\n",
    "_logp = trainer.validate(data_loader, 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "print(np.round(_logp, decimals = 6))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "len(test_smiles)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# conditional generation performance over multiple runs on the full test set\n",
    "### use L2 and L1 norm on the (y-y')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "## mean mse\n",
    "test_smiles, test_prop = np.load(\"../../data/Guacamol/test_props.npz\").values()\n",
    "subset = np.random.permutation(len(test_smiles))\n",
    "subset = subset[:10000]\n",
    "test_smiles = test_smiles[subset]\n",
    "test_properties = test_prop[subset,:]\n",
    "MSE = []\n",
    "MES_per_dim = []\n",
    "MAE = []\n",
    "MAE_per_dim = []\n",
    "for i in range(10):\n",
    "    #test_smiles, test_prop = np.load(\"../../data/QM9/QM9_clean_smi_test_smile.npz\", allow_pickle=True).values()\n",
    "    #test_smiles, test_prop = np.load(\"../../data/Guacamol/test_props.npz\").values()\n",
    "    np.random.seed(1)\n",
    "    #subset = np.random.permutation(len(test_smiles))\n",
    "    sample_size = 1\n",
    "    \n",
    "    generated = sampler.sample(model, normalize(test_properties), sample_size)\n",
    "\n",
    "    simulated = []\n",
    "    index = []\n",
    "    for row in range(generated.shape[0]):\n",
    "        prop = []\n",
    "        canonical = []\n",
    "        for sample in generated[row]:\n",
    "            mol = Chem.MolFromSmiles(sample)\n",
    "            try:\n",
    "                prop.append([p(mol) for p in PROPERTIES.values()])\n",
    "                canonical.append(Chem.MolToSmiles(mol))\n",
    "            except:\n",
    "                pass # molecule invalid\n",
    "\n",
    "        if prop !=[]:\n",
    "            index.append(row)\n",
    "\n",
    "        simulated.append(np.array(prop))\n",
    "\n",
    "    if simulated[0].shape[0]>1:\n",
    "        mean_prop= [item.mean(axis = 0 ) for item in simulated]\n",
    "\n",
    "    else:\n",
    "        mean_prop = simulated\n",
    "\n",
    "    valid_smiles_prop = np.array(mean_prop)[index]\n",
    "    valid_smiles_prop = np.vstack(valid_smiles_prop)\n",
    "    test_prop = test_properties[index]   \n",
    "\n",
    "    mse = np.mean(np.power(valid_smiles_prop - test_prop, 2)) \n",
    "    mae = np.mean(np.absolute(valid_smiles_prop - test_prop)) \n",
    "    print('Mean square error:', mse)\n",
    "    per_prop = np.mean(np.power(valid_smiles_prop - test_prop, 2),axis = 0)\n",
    "    mae_per_prop = np.mean(np.absolute(valid_smiles_prop - test_prop),axis = 0)\n",
    "    \n",
    "    MSE.append(mse)\n",
    "    MES_per_dim.append(per_prop)\n",
    "    MAE.append(mae)\n",
    "    MAE_per_dim.append(mae_per_prop)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print('Mean square error')\n",
    "print(\"%.4f\" u\"\\u00B1\" \"%.4f\"%(np.array(MSE).mean(), np.array(MSE).std()))\n",
    "print('Mean absolute error')\n",
    "print(\"%.4f\" u\"\\u00B1\" \"%.4f\"%(np.array(MAE).mean(), np.array(MAE).std()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print('Mean square error')\n",
    "mean_per = np.array(MES_per_dim).mean(axis = 0)\n",
    "std_per = np.array(MES_per_dim).std(axis = 0)\n",
    "for i in range(9):\n",
    "    print(\"%.4f\" u\"\\u00B1\" \"%.4f\" %(mean_per[i],std_per[i]))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print('Mean absolute error')\n",
    "MAE_mean_per = np.array(MAE_per_dim).mean(axis = 0)\n",
    "MAE_std_per = np.array(MAE_per_dim).std(axis = 0)\n",
    "for i in range(9):\n",
    "    print(\"%.4f\" u\"\\u00B1\" \"%.4f\" %(MAE_mean_per[i],MAE_std_per[i]))\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# statistics over multiple property"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#lstm\n",
    "val_n = np.array([0.9619,0.9642,0.9701])\n",
    "print(\"%.4f\" u\"\\u00B1\" \"%.4f\" %(val_n.mean(),val_n.std()))\n",
    "\n",
    "unicity_n = np.array([0.9667,0.9694,0.96618])\n",
    "print(\"%.4f\" u\"\\u00B1\" \"%.4f\" %(unicity_n.mean(),unicity_n.std()))\n",
    "\n",
    "Nov_n = np.array([0.3660,0.3581,0.37315])\n",
    "print(\"%.4f\" u\"\\u00B1\" \"%.4f\" %(Nov_n.mean(),Nov_n.std()))\n",
    "\n",
    "_logp_n = np.array([0.21800,0.21447,0.21329])\n",
    "print(\"%.4f\" u\"\\u00B1\" \"%.4f\" %(_logp_n.mean(),_logp_n.std()))\n",
    "\n",
    "mse_n = np.array([10.337,9.7910,8.9395])\n",
    "print(\"%.4f\" u\"\\u00B1\" \"%.4f\" %(mse_n.mean(),mse_n.std()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "## our\n",
    "val_n = np.array([0.9886,0.9863,0.9909])\n",
    "print(\"%.4f\" u\"\\u00B1\" \"%.4f\" %(val_n.mean(),val_n.std()))\n",
    "\n",
    "unicity_n = np.array([0.9629,0.9634,0.9641])\n",
    "print(\"%.4f\" u\"\\u00B1\" \"%.4f\" %(unicity_n.mean(),unicity_n.std()))\n",
    "\n",
    "Nov_n = np.array([0.2605,0.2676,0.2345])\n",
    "print(\"%.4f\" u\"\\u00B1\" \"%.4f\" %(Nov_n.mean(),Nov_n.std()))\n",
    "\n",
    "_logp_n = np.array([0.2357,0.2335,0.2437])\n",
    "print(\"%.4f\" u\"\\u00B1\" \"%.4f\" %(_logp_n.mean(),_logp_n.std()))\n",
    "\n",
    "mse_n = np.array([7.0603,8.4115,6.4833])\n",
    "print(\"%.4f\" u\"\\u00B1\" \"%.4f\" %(mse_n.mean(),mse_n.std()))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# lets see if the hidden state represents well the mol structure"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from rnn_utils import rnn_start_token_vector\n",
    "from torch.distributions import Categorical, Distribution\n",
    "import torch.nn.functional as F\n",
    "batch_size = 100\n",
    "max_seq_length = 101\n",
    "device = 'cpu'\n",
    "hidden = model.init_hidden(batch_size, device)\n",
    "inp = rnn_start_token_vector(batch_size, device)\n",
    "\n",
    "actions = torch.zeros((batch_size, max_seq_length), dtype=torch.long).to(device)\n",
    "distribution_cls= None\n",
    "distribution_cls = Categorical if distribution_cls is None else distribution_cls"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "properties = torch.FloatTensor(test_properties)\n",
    "for char in range(max_seq_length):\n",
    "    output, hidden = model(inp, properties, hidden)\n",
    "\n",
    "    prob = F.softmax(output, dim=2)\n",
    "    distribution = distribution_cls(probs=prob)\n",
    "    action = distribution.sample()\n",
    "\n",
    "    actions[:, char] = action.squeeze()\n",
    "\n",
    "    inp = action"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "inp = rnn_start_token_vector(batch_size, device)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "max_len = 100\n",
    "device = 'cpu'\n",
    "from rnn_utils import get_tensor_dataset, load_smiles_and_properties, set_random_seed\n",
    "training_set=\"../../data/QM9/QM9_clean_smi_train_smile.npz\"\n",
    "train_seqs, train_prop = load_smiles_and_properties(training_set, max_len)\n",
    "sample_indexs = np.arange(train_seqs.shape[0])\n",
    "\n",
    "train_x, train_y = train_seqs[10000:,:], train_prop[10000:,:]\n",
    "valid_x, valid_y = train_seqs[:10000,:], train_prop[:10000,:]\n",
    "\n",
    "all_y = np.concatenate((train_y, valid_y), axis=0)\n",
    "mean = np.mean(all_y, axis = 0)\n",
    "std = np.std(all_y, axis = 0)\n",
    "\n",
    "#np.save(args.data_path + '/normalizer.py', [mean, std])\n",
    "\n",
    "train_y = (train_y - mean ) / std\n",
    "valid_y = (valid_y - mean) / std"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "batch_size = 1000\n",
    "inp = torch.from_numpy(train_x[:batch_size,:-1]).long()\n",
    "properties  = train_y[:batch_size,:]\n",
    "properties = torch.FloatTensor(properties)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "hidden = model.init_hidden(inp.size(0), device)\n",
    "output, hidden = model(inp, properties, hidden) "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "hidden[1][2].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "hidden[1][2].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "smiles_list, properties = np.load(training_set, allow_pickle=True).values()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "smiles_list = smiles_list[:batch_size]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from rdkit import Chem\n",
    "from rdkit.Chem import Draw\n",
    "from rdkit.Chem import AllChem\n",
    "from rdkit import DataStructs\n",
    "tanimoto_sim = []\n",
    "for i in range(batch_size):\n",
    "    fp_A = AllChem.GetMorganFingerprint(Chem.MolFromSmiles(smiles_list[0]), 2)\n",
    "    fp_B =  AllChem.GetMorganFingerprint(Chem.MolFromSmiles(smiles_list[i]), 2)\n",
    "    sim = DataStructs.TanimotoSimilarity(fp_A,fp_B)\n",
    "    tanimoto_sim.append(sim)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "np.array(tanimoto_sim).argsort()[-20:]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fpA"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "Draw.MolToImage(Chem.MolFromSmiles(smiles_list[0]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dist = []\n",
    "target = hidden[1][2][0,:]\n",
    "for i in range(batch_size):\n",
    "    l_2_dist = torch.pow((hidden[1][2])[i,:] -  target, 2).mean()\n",
    "    dist.append(l_2_dist)\n",
    "   "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "np.array(dist).argsort()[:20]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "properties = torch.FloatTensor(properties)\n",
    "for char in range(max_seq_length):\n",
    "    output, hidden = model(inp, properties, hidden)\n",
    "\n",
    "    prob = F.softmax(output, dim=2)\n",
    "    distribution = distribution_cls(probs=prob)\n",
    "    action = distribution.sample()\n",
    "\n",
    "    actions[:, char] = action.squeeze()\n",
    "    inp = action"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
