{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 1\n",
    "\n",
    "import pandas as pd\n",
    "import os\n",
    "import sys\n",
    "\n",
    "# sys.path entries must be directories, so register the folder that holds\n",
    "# binarygridsearch.py rather than the path of the file itself.\n",
    "scriptpath = \"binarygridsearch.py\"\n",
    "script_dir = os.path.dirname(os.path.abspath(scriptpath))\n",
    "if script_dir not in sys.path:\n",
    "    sys.path.append(script_dir)\n",
    "\n",
    "# Import the search helpers and register the module with %aimport, which is\n",
    "# what `%autoreload 1` uses to decide which modules to reload.\n",
    "import binarygridsearch as bgs\n",
    "%aimport binarygridsearch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "from scipy.stats import uniform, randint\n",
    "from sklearn.metrics import roc_auc_score\n",
    "from sklearn.ensemble import RandomForestClassifier\n",
    "from sklearn.model_selection import GridSearchCV\n",
    "from sklearn.model_selection import RandomizedSearchCV"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def getForestAccuracy(X, y, metric, kwargs):\n",
    "    \"\"\"Fit a random forest and score its out-of-bag (OOB) predictions.\n",
    "\n",
    "    Despite the name, the returned value is whatever ``metric`` computes\n",
    "    (for example ROC AUC), not necessarily accuracy.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    X, y : training data and labels handed to ``clf.fit``.\n",
    "    metric : callable invoked as ``metric(y, y_pred)``.\n",
    "    kwargs : dict of keyword arguments for ``RandomForestClassifier``.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    The value of ``metric(y, y_pred)``, where ``y_pred`` is column 1 of\n",
    "    the fitted forest's ``oob_decision_function_``.\n",
    "    \"\"\"\n",
    "    # oob_decision_function_ only exists when oob_score=True, so enable it\n",
    "    # by default; an explicit value in kwargs still wins and the caller's\n",
    "    # dict is left untouched.\n",
    "    forest_params = {\"oob_score\": True, **kwargs}\n",
    "    clf = RandomForestClassifier(**forest_params)\n",
    "    clf.fit(X, y)\n",
    "    y_pred = clf.oob_decision_function_[:, 1]\n",
    "    return metric(y, y_pred)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.datasets import load_breast_cancer\n",
    "# Load the breast-cancer dataset and split it into features and labels.\n",
    "data = load_breast_cancer()\n",
    "X = data.data\n",
    "y = data.target"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "rfArgs = {\"random_state\": 0,\n",
    "          \"n_jobs\": -1,\n",
    "          \"class_weight\": \"balanced\",\n",
    "         \"n_estimators\": 18,\n",
    "         \"oob_score\": True}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "hyperparameters = [[\"max_depth\", 0, 1, 32],\n",
    "                  [\"min_samples_split\", 2, 0.01, 0.1]]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'max_depth': 32, 'min_samples_split': 0.1}\n",
      "0.9837019713545796\n",
      "CPU times: user 97.1 ms, sys: 48.3 ms, total: 145 ms\n",
      "Wall time: 1 s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "# Run the search from binarygridsearch, passing getForestAccuracy, the base\n",
    "# forest arguments, roc_auc_score and the hyperparameter specs.\n",
    "dct = bgs.binarySearchParamsParallel(\n",
    "    X, y, getForestAccuracy, rfArgs, roc_auc_score, hyperparameters\n",
    ")\n",
    "\n",
    "print(dct['values'])\n",
    "print(dct['score'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "10"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Show the n_iterations entry of the most recent search result.\n",
    "dct['n_iterations']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "param_dist= {\"max_depth\": randint(1, 32) ,\n",
    "            \"min_samples_split\": uniform(loc=0.01, scale=0.09) }"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'max_depth': 20, 'min_samples_split': 0.06813047017599905}\n",
      "0.9893966403611371\n",
      "CPU times: user 267 ms, sys: 42.3 ms, total: 309 ms\n",
      "Wall time: 3.2 s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "# Baseline for comparison: 10 random draws from param_dist, each scored by\n",
    "# 3-fold cross-validated ROC AUC.\n",
    "RF = RandomForestClassifier(**rfArgs)\n",
    "clf = RandomizedSearchCV(\n",
    "    estimator=RF,\n",
    "    param_distributions=param_dist,\n",
    "    n_iter=10,\n",
    "    scoring='roc_auc',\n",
    "    cv=3,\n",
    "    n_jobs=-1,\n",
    "    random_state=0,\n",
    "    verbose=0,\n",
    ")\n",
    "# fit() returns the search object itself, so best_model is the fitted search.\n",
    "best_model = clf.fit(X, y)\n",
    "print(best_model.best_params_)\n",
    "print(best_model.best_score_)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "hyperparameters = [[\"max_depth\", 0, 1, 32],\n",
    "                  [\"min_samples_split\", 3, 0.03, 0.1]]\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'max_depth': 32, 'min_samples_split': 0.1}\n",
      "0.9837019713545796\n",
      "CPU times: user 84.8 ms, sys: 54.9 ms, total: 140 ms\n",
      "Wall time: 1.4 s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "# Repeat the search with the current hyperparameter specs.\n",
    "dct = bgs.binarySearchParamsParallel(\n",
    "    X, y, getForestAccuracy, rfArgs, roc_auc_score, hyperparameters\n",
    ")\n",
    "\n",
    "print(dct['values'])\n",
    "print(dct['score'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "hyperparameters = [['max_depth', 0, 1, 100],\n",
    "                  [\"min_samples_split\", 2, 0.01, 0.1],\n",
    "                  [\"min_samples_leaf\", 2, 0.01, 0.1],]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'max_depth': 100, 'min_samples_split': 0.1, 'min_samples_leaf': 0.030000000000000002}\n",
      "0.9860208234237091\n"
     ]
    }
   ],
   "source": [
    "# Run the search over the current hyperparameter specs and report the result.\n",
    "dct = bgs.binarySearchParamsParallel(\n",
    "    X, y, getForestAccuracy, rfArgs, roc_auc_score, hyperparameters\n",
    ")\n",
    "\n",
    "print(dct['values'])\n",
    "print(dct['score'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'max_depth': 100, 'min_samples_split': 0.1, 'min_samples_leaf': 0.01}\n",
      "0.9853535753924211\n"
     ]
    }
   ],
   "source": [
    "# Same three parameters, with 3 instead of 2 as the second field of the\n",
    "# two fractional specs.\n",
    "hyperparameters = [\n",
    "    [\"max_depth\", 0, 1, 100],\n",
    "    [\"min_samples_split\", 3, 0.01, 0.1],\n",
    "    [\"min_samples_leaf\", 3, 0.01, 0.1],\n",
    "]\n",
    "\n",
    "dct = bgs.binarySearchParamsParallel(\n",
    "    X, y, getForestAccuracy, rfArgs, roc_auc_score, hyperparameters\n",
    ")\n",
    "\n",
    "print(dct['values'])\n",
    "print(dct['score'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:py36]",
   "language": "python",
   "name": "conda-env-py36-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
