{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Learning Parameters\n",
    "\n",
    "**Note:** You may `unzip` the file `data/processed.zip` and skip this notebook.\n",
    "\n",
    "**Note:** Some of this code, including the entire `sepsisSimDiabetes/` directory, is borrowed from [Counterfactual Off-Policy Evaluation with Gumbel-Max Structural Causal Models](https://github.com/clinicalml/gumbel-max-scm)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This notebook learns the following objects, which are needed for the main simulation:\n",
    "\n",
    "1. `optimal_policy_st`: the soft optimal policy. With probability 0.95 it takes the action recommended by the optimal policy, and with probability 0.05 it picks uniformly at random among the other actions.\n",
    "2. `mixed_policy`: the behaviour policy used from the second step onward in the main simulation. It is a mixture of two policies:\n",
    "    * 85\\% of `optimal_policy_st`\n",
    "    * 15\\% of a sub-optimal policy that matches the optimal policy except that the vasopressor action is flipped. It is also softened with probability 0.05 of a random action.\n",
    "3. `t0_policy`: the behaviour policy used at the first time step. It consists of two policies:\n",
    "    * `with antibiotics`: similar to the soft optimal policy, except that the probability mass of each action without antibiotics is moved to the corresponding action with antibiotics.\n",
    "    * `without antibiotics`: similar to the soft optimal policy, except that the probability mass of each action with antibiotics is moved to the corresponding action without antibiotics.\n",
    "4. `value_function`: the value function of the optimal policy, learned by value iteration.\n",
    "5. `tx_tr.pkl`: a tuple `(tx, tr)` holding the transition matrix and the reward matrix, each of shape `(nA, nS + 1, nS + 1)`; the extra state is the terminal state added during preprocessing."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Pre**: This notebook uses the data learned [here](https://github.com/clinicalml/gumbel-max-scm), in [learn_mdp_parameters.ipynb](https://github.com/clinicalml/gumbel-max-scm/blob/master/learn_mdp_parameters.ipynb). You may run their code to produce `data/diab_txr_mats-replication.pkl`, or unzip the file `data/oberst_sontag.zip` to obtain the same pickle file."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Output**: After running this code you should see the following files in directory `data`:\n",
    "\n",
    "`optimal_policy_st.pkl`, `mixed_policy.pkl`, `tx_tr.pkl`, `t0_policy.pkl`, `value_function.pkl`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pickle\n",
    "import tqdm as tqdm\n",
    "from scipy.linalg import block_diag\n",
    "\n",
    "from utils.utils import MatrixMDP\n",
    "from core.sepsisSimDiabetes.State import State\n",
    "from core.sepsisSimDiabetes.Action import Action\n",
    "import core.sepsisSimDiabetes.MDP as simulator"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The following `config` dictionary contains the necessary information for the rest of the notebook:\n",
    "\n",
    "- `prob_diab`: probability of having diabetes, default = 0.2\n",
    "- `nS`: number of states\n",
    "- `nA`: number of actions\n",
    "- `discount`: MDP discount factor ($\\gamma$)\n",
    "- `epsilon`: total probability assigned to non-optimal actions when building a soft policy ($\\epsilon$), default = 0.05\n",
    "- `mixture_prob`: weight of the soft optimal policy in the mixed policy, default = 0.85"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "config = {'prob_diab': 0.2, 'nS': State.NUM_FULL_STATES,\n",
    "          'nA': Action.NUM_ACTIONS_TOTAL, 'discount': 0.99,\n",
    "          'epsilon': 0.05, 'mixture_prob': 0.85}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Preprocessing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "with open(\"data/diab_txr_mats-replication.pkl\", \"rb\") as f:\n",
    "    mdict = pickle.load(f)\n",
    "\n",
    "tx_mat = mdict[\"tx_mat\"]\n",
    "r_mat = mdict[\"r_mat\"]\n",
    "p_mixture = np.array([1 - config['prob_diab'], config['prob_diab']])\n",
    "\n",
    "tx_mat_full = np.zeros((config['nA'], config['nS'], config['nS']))\n",
    "r_mat_full = np.zeros((config['nA'], config['nS'], config['nS']))\n",
    "\n",
    "for a in range(config['nA']):\n",
    "    tx_mat_full[a, ...] = block_diag(tx_mat[0, a, ...], tx_mat[1, a,...])\n",
    "    r_mat_full[a, ...] = block_diag(r_mat[0, a, ...], r_mat[1, a, ...])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`tx_mat_full` and `r_mat_full` do not encode the terminal state. The code below extends them with one extra (terminal) state. Once the agent reaches the terminal state, it can only transition to that same state, with probability 1 and reward 0."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Add one extra state (the terminal state) to both matrices\n",
    "tx = np.zeros((config['nA'], config['nS'] + 1, config['nS'] + 1))\n",
    "tx[:, :config['nS'], :config['nS']] = np.copy(tx_mat_full)\n",
    "\n",
    "tr = np.zeros((config['nA'], config['nS'] + 1, config['nS'] + 1))\n",
    "tr[:, :config['nS'], :config['nS']] = np.copy(r_mat_full)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A transition `(a, s, s')` with `R(a, s, s') = 1` or `-1` marks `s'` as an end-of-episode state: from `s'`, every action moves to the terminal state with probability 1."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "for s0 in range(config['nS'] + 1):\n",
    "    for a in range(config['nA']):\n",
    "        for s1 in range(config['nS'] + 1):\n",
    "            # R(a, s0, s1) of 1 or -1 marks s1 as an end-of-episode state\n",
    "            if tr[a, s0, s1] == 1 or tr[a, s0, s1] == -1:\n",
    "                tx[:, s1, :] = 0  # no transition from s1 to any regular state\n",
    "                tx[:, s1, config['nS']] = 1  # s1 moves to the terminal state\n",
    "# The terminal state transitions to itself with probability 1\n",
    "# and yields reward 0\n",
    "tx[:, config['nS'], config['nS']] = 1\n",
    "tr[:, config['nS'], config['nS']] = 0"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`tx_tr.pkl` is a tuple of `(transition matrix, reward matrix)`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "with open('data/tx_tr.pkl', 'wb') as f:\n",
    "    pickle.dump((tx, tr), f)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Optimal Policy and Value function"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "First build an MDP with transition `tx` and reward `tr`, using the `MatrixMDP` class. This step requires `mdptoolboxSrc` in the `utils` folder. Then, perform policy iteration to get the optimal policy and value iteration to get the value function."
   ]
  },
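  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For reference, value iteration approximates the fixed point of the Bellman optimality equation. The form below is the standard textbook one, written in notation introduced here for illustration ($P$ for `tx`, $R$ for `tr`, $\\gamma$ for `config['discount']`); it is not taken from the `MatrixMDP` implementation:\n",
    "\n",
    "$$\n",
    "V^*(s) = \\max_a \\sum_{s'} P(s' \\mid s, a) \\left[ R(a, s, s') + \\gamma V^*(s') \\right]\n",
    "$$\n",
    "\n",
    "Under this equation the terminal state $s_T$ has value 0: it transitions only to itself with reward 0, so $V^*(s_T) = \\gamma V^*(s_T)$, which forces $V^*(s_T) = 0$ for $\\gamma < 1$."
   ]
  },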
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "MDP = MatrixMDP(tx, tr)\n",
    "# Policy Iteration for the optimal policy\n",
    "optimal_policy = MDP.policyIteration(discount=config['discount'], \n",
    "                                     eval_type=1).argmax(axis=1)\n",
    "# Value Iteration for the value function\n",
    "V = MDP.valueIteration(discount=config['discount'], epsilon=0.001,\n",
    "                       max_iter=5000)\n",
    "value_function = np.array(V)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Build the soft optimal policy `optimal_policy_st` by assigning probability `1 - epsilon` to the optimal action and spreading `epsilon` equally over the remaining `nA - 1` actions."
   ]
  },
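  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a compact restatement of the construction above (a sketch in notation introduced here for illustration: $\\pi^*$ for the optimal policy, $\\pi_{st}$ for `optimal_policy_st`, $\\epsilon$ for `config['epsilon']`, and $n_A$ for `config['nA']`):\n",
    "\n",
    "$$\n",
    "\\pi_{st}(a \\mid s) = \\begin{cases} 1 - \\epsilon & \\text{if } a = \\pi^*(s), \\\\ \\dfrac{\\epsilon}{n_A - 1} & \\text{otherwise.} \\end{cases}\n",
    "$$\n",
    "\n",
    "Assuming the 8 actions `0` to `7` used in the rest of this notebook and $\\epsilon = 0.05$, each non-optimal action gets $0.05 / 7 \\approx 0.0071$, so every row sums to $0.95 + 7 \\cdot 0.05 / 7 = 1$."
   ]
  },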
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One-hot encode the optimal action, with one column per action (nA columns)\n",
    "optimal_policy_st = np.zeros((optimal_policy.size, config['nA']))\n",
    "optimal_policy_st[np.arange(optimal_policy.size), optimal_policy] = 1\n",
    "\n",
    "# Soften: 1 - epsilon on the optimal action, epsilon shared by the others\n",
    "optimal_policy_st[optimal_policy_st == 1] = 1 - config['epsilon']\n",
    "optimal_policy_st[optimal_policy_st == 0] = config['epsilon'] / (config['nA'] - 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "with open('data/optimal_policy_st.pkl', 'wb') as f:\n",
    "    pickle.dump(optimal_policy_st, f)\n",
    "with open('data/value_function.pkl', 'wb') as f:\n",
    "    pickle.dump(value_function, f)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Mixed Policy"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`mixed_policy` is a mixture of two policies:\n",
    "1. 85% of `optimal_policy_st`\n",
    "2. 15% of a sub-optimal policy that matches the optimal policy except that the vasopressor action is flipped; it is then softened with probability 0.05.\n",
    "\n",
    "First, build the sub-optimal policy by replacing each optimal action with the action that has the vasopressor setting flipped (`on` becomes `off`, and the other way around).\n",
    "\n",
    "In this encoding, the mapping between such actions is as follows:\n",
    "```python\n",
    "mapping = {1:0, 0:1, 2:3, 3:2, 4:5, 5:4, 6:7, 7:6}\n",
    "```\n",
    "For example, the action with a flipped vasopressor setting is 1 for action 0, and vice versa."
   ]
  },
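  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The mixture described above can be written out as follows. This is a sketch in notation introduced here for illustration: $\\pi_{st}$ is `optimal_policy_st`, $\\tilde{\\pi}_{st}$ is the softened flipped policy, and $p$ is `config['mixture_prob']`:\n",
    "\n",
    "$$\n",
    "\\pi_{mix}(a \\mid s) = p \\, \\pi_{st}(a \\mid s) + (1 - p) \\, \\tilde{\\pi}_{st}(a \\mid s)\n",
    "$$\n",
    "\n",
    "Assuming 8 actions, $\\epsilon = 0.05$ and $p = 0.85$, a state with optimal action $a^*$ and flipped action $a'$ gets:\n",
    "\n",
    "$$\n",
    "\\pi_{mix}(a^* \\mid s) = 0.85 \\cdot 0.95 + 0.15 \\cdot \\tfrac{0.05}{7} \\approx 0.8086, \\qquad \\pi_{mix}(a' \\mid s) = 0.85 \\cdot \\tfrac{0.05}{7} + 0.15 \\cdot 0.95 \\approx 0.1486\n",
    "$$\n",
    "\n",
    "Each of the remaining 6 actions keeps probability $0.05 / 7 \\approx 0.0071$, so the row sums to 1."
   ]
  },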
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "mapping = {1:0, 0:1, 2:3, 3:2, 4:5, 5:4, 6:7, 7:6}\n",
    "# flip the vasopressor setting of the optimal action in every state\n",
    "mod_policy = np.copy(optimal_policy)\n",
    "for s in range(mod_policy.shape[0]):\n",
    "    mod_policy[s] = mapping[mod_policy[s]]\n",
    "\n",
    "# make the mod_policy soft (one column per action)\n",
    "mod_policy_st = np.zeros((mod_policy.size, config['nA']))\n",
    "mod_policy_st[np.arange(mod_policy.size), mod_policy] = 1\n",
    "\n",
    "mod_policy_st[mod_policy_st == 1] = 1 - config['epsilon']\n",
    "mod_policy_st[mod_policy_st == 0] = config['epsilon'] / (config['nA'] - 1)\n",
    "\n",
    "# mix two policies\n",
    "mixed_policy = config['mixture_prob'] * optimal_policy_st +\\\n",
    "              (1-config['mixture_prob']) * mod_policy_st"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "with open('data/mixed_policy.pkl', 'wb') as f:\n",
    "    pickle.dump(mixed_policy, f)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### First time step policy (t0_policy)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`t0_policy` consists of two policies:\n",
    "   1. `with antibiotics`: similar to the soft optimal policy, except that the probability mass of each action without antibiotics is moved to the corresponding action with antibiotics.\n",
    "   2. `without antibiotics`: similar to the soft optimal policy, except that the probability mass of each action with antibiotics is moved to the corresponding action without antibiotics.\n",
    "\n",
    "The mapping relies on the action encoding: actions `(0, 1, 2, 3)` are without antibiotics, and their counterparts with antibiotics are `(4, 5, 6, 7)`, respectively.\n",
    "\n",
    "`t0_policy` has shape `(2, nS + 1, nA)`, where the extra state is the terminal state; `(0, :, :)` is the `with antibiotics` policy and `(1, :, :)` is the `without antibiotics` policy."
   ]
  },
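  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The two policies described above can be sketched as follows, in notation introduced here for illustration ($\\pi_{st}$ for `optimal_policy_st`, and $a \\in \\{0, 1, 2, 3\\}$ for the actions without antibiotics):\n",
    "\n",
    "$$\n",
    "\\pi^{with}(a + 4 \\mid s) = \\pi_{st}(a \\mid s) + \\pi_{st}(a + 4 \\mid s), \\qquad \\pi^{with}(a \\mid s) = 0\n",
    "$$\n",
    "\n",
    "$$\n",
    "\\pi^{without}(a \\mid s) = \\pi_{st}(a \\mid s) + \\pi_{st}(a + 4 \\mid s), \\qquad \\pi^{without}(a + 4 \\mid s) = 0\n",
    "$$\n",
    "\n",
    "As an example, assume 8 actions, $\\epsilon = 0.05$ and a state whose optimal action is 1. The `with antibiotics` policy then puts $0.95 + 0.05 / 7 \\approx 0.9571$ on action 5 and $2 \\cdot 0.05 / 7 \\approx 0.0143$ on each of actions 4, 6 and 7, which again sums to 1."
   ]
  },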
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "t0_policy = np.zeros((2, config['nS'] + 1, config['nA']))\n",
    "# note that nS + 1 accounts for the extra terminal state\n",
    "\n",
    "t0_policy[0, :, :] = np.copy(optimal_policy_st)\n",
    "t0_policy[1, :, :] = np.copy(optimal_policy_st)\n",
    "\n",
    "# With antibiotics\n",
    "t0_policy[0, :, [7,6,5,4]] = t0_policy[0, :, [7,6,5,4]] +\\\n",
    "                             t0_policy[0, :, [3,2,1,0]]\n",
    "t0_policy[0, :, [3,2,1,0]] = 0\n",
    "# Without antibiotics\n",
    "t0_policy[1, :, [3,2,1,0]] = t0_policy[1, :, [7,6,5,4]] +\\\n",
    "                             t0_policy[1, :, [3,2,1,0]]\n",
    "t0_policy[1, :, [7,6,5,4]] = 0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "with open('data/t0_policy.pkl', 'wb') as f:\n",
    "    pickle.dump(t0_policy, f)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
