{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Quantizing RNN Models"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In this example, we show how to quantize recurrent models.  \n",
    "Using a pretrained model `model.RNNModel`, we convert the built-in pytorch implementation of LSTM to our own, modular implementation.  \n",
    "The pretrained model was generated with:  \n",
    "```time python3 main.py --cuda --emsize 1500 --nhid 1500 --dropout 0.65 --tied --wd=1e-6```  \n",
    "The reason we replace the LSTM that is because the inner operations in the pytorch implementation are not accessible to us, but we still want to quantize these operations. <br />\n",
    "Afterwards we can try different techniques to quantize the whole model.  "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from model import DistillerRNNModel, RNNModel\n",
    "from data import Corpus\n",
    "import torch\n",
    "from torch import nn\n",
    "import distiller\n",
    "from distiller.modules import DistillerLSTM as LSTM\n",
    "from tqdm import tqdm # for pretty progress bar\n",
    "import numpy as np\n",
    "from copy import deepcopy\n",
    "\n",
    "import warnings\n",
    "warnings.filterwarnings(action='default', module='distiller.quantization')\n",
    "warnings.filterwarnings(action='default', module='distiller.quantization.range_linear')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Preprocess the data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "corpus = Corpus('./data/wikitext-2/')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def batchify(data, bsz):\n",
    "    # Work out how cleanly we can divide the dataset into bsz parts.\n",
    "    nbatch = data.size(0) // bsz\n",
    "    # Trim off any extra elements that wouldn't cleanly fit (remainders).\n",
    "    data = data.narrow(0, 0, nbatch * bsz)\n",
    "    # Evenly divide the data across the bsz batches.\n",
    "    data = data.view(bsz, -1).t().contiguous()\n",
    "    return data.to(device)\n",
    "device = 'cuda:0'\n",
    "batch_size = 20\n",
    "eval_batch_size = 10\n",
    "train_data = batchify(corpus.train, batch_size)\n",
    "val_data = batchify(corpus.valid, eval_batch_size)\n",
    "test_data = batchify(corpus.test, eval_batch_size)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Loading the model and converting to our own implementation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "RNNModel(\n",
       "  (drop): Dropout(p=0.65)\n",
       "  (encoder): Embedding(33278, 1500)\n",
       "  (rnn): LSTM(1500, 1500, num_layers=2, dropout=0.65)\n",
       "  (decoder): Linear(in_features=1500, out_features=33278, bias=True)\n",
       ")"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "rnn_model = torch.load('./checkpoint.pth.tar.best')\n",
    "rnn_model = rnn_model.to(device)\n",
    "rnn_model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Here we convert the pytorch LSTM implementation to our own, by calling `LSTM.from_pytorch_impl`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "DistillerRNNModel(\n",
       "  (encoder): Embedding(33278, 1500)\n",
       "  (rnn): DistillerLSTM(1500, 1500, num_layers=2, dropout=0.65, bidirectional=False)\n",
       "  (decoder): Linear(in_features=1500, out_features=33278, bias=True)\n",
       ")"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def manual_model(pytorch_model_: RNNModel):\n",
    "    nlayers, ninp, nhid, ntoken, tie_weights = \\\n",
    "        pytorch_model_.nlayers, \\\n",
    "        pytorch_model_.ninp, \\\n",
    "        pytorch_model_.nhid, \\\n",
    "        pytorch_model_.ntoken, \\\n",
    "        pytorch_model_.tie_weights\n",
    "\n",
    "    model = DistillerRNNModel(nlayers=nlayers, ninp=ninp, nhid=nhid, ntoken=ntoken, tie_weights=tie_weights).to(device)\n",
    "    model.eval()\n",
    "    model.encoder.weight = nn.Parameter(pytorch_model_.encoder.weight.clone().detach())\n",
    "    model.decoder.weight = nn.Parameter(pytorch_model_.decoder.weight.clone().detach())\n",
    "    model.decoder.bias = nn.Parameter(pytorch_model_.decoder.bias.clone().detach())\n",
    "    model.rnn = LSTM.from_pytorch_impl(pytorch_model_.rnn)\n",
    "\n",
    "    return model\n",
    "\n",
    "man_model = manual_model(rnn_model)\n",
    "torch.save(man_model, 'manual.checkpoint.pth.tar')\n",
    "man_model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Batching the data for evaluation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "sequence_len = 35\n",
    "def get_batch(source, i):\n",
    "    seq_len = min(sequence_len, len(source) - 1 - i)\n",
    "    data = source[i:i+seq_len]\n",
    "    target = source[i+1:i+1+seq_len].view(-1)\n",
    "    return data, target\n",
    "\n",
    "hidden = rnn_model.init_hidden(eval_batch_size)\n",
    "data, targets = get_batch(test_data, 0)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Check that the convertion has succeeded"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Max error in y: 0.000006\n"
     ]
    }
   ],
   "source": [
    "rnn_model.eval()\n",
    "man_model.eval()\n",
    "y_t, h_t = rnn_model(data, hidden)\n",
    "y_p, h_p = man_model(data, hidden)\n",
    "\n",
    "print(\"Max error in y: %f\" % (y_t-y_p).abs().max().item())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Defining the evaluation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "criterion = nn.CrossEntropyLoss()\n",
    "def repackage_hidden(h):\n",
    "    \"\"\"Wraps hidden states in new Tensors, to detach them from their history.\"\"\"\n",
    "    if isinstance(h, torch.Tensor):\n",
    "        return h.detach()\n",
    "    else:\n",
    "        return tuple(repackage_hidden(v) for v in h)\n",
    "    \n",
    "\n",
    "def evaluate(model, data_source):\n",
    "    # Turn on evaluation mode which disables dropout.\n",
    "    model.eval()\n",
    "    total_loss = 0.\n",
    "    ntokens = len(corpus.dictionary)\n",
    "    hidden = model.init_hidden(eval_batch_size)\n",
    "    with torch.no_grad():\n",
    "        with tqdm(range(0, data_source.size(0), sequence_len)) as t:\n",
    "            # The line below was fixed as per: https://github.com/pytorch/examples/issues/214\n",
    "            for i in t:\n",
    "                data, targets = get_batch(data_source, i)\n",
    "                output, hidden = model(data, hidden)\n",
    "                output_flat = output.view(-1, ntokens)\n",
    "                total_loss += len(data) * criterion(output_flat, targets).item()\n",
    "                hidden = repackage_hidden(hidden)\n",
    "                avg_loss = total_loss / (i + 1)\n",
    "                t.set_postfix((('val_loss', avg_loss), ('ppl', np.exp(avg_loss))))\n",
    "    return total_loss / len(data_source)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Quantizing the Model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Collect activation statistics"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The model uses activation statistics to determine how big the quantization range is. The bigger the range - the larger the round off error after quantization which leads to accuracy drop.  \n",
    "Our goal is to minimize the range s.t. it contains the absolute most of our data.  \n",
    "After that, we divide the range into chunks of equal size, according to the number of bits, and transform the data according to this scale factor.  \n",
    "Read more on scale factor calculation [in our docs](https://nervanasystems.github.io/distiller/algo_quantization.html).\n",
    "\n",
    "The class `QuantCalibrationStatsCollector` collects the statistics for defining the range $r = max - min$.  \n",
    "\n",
    "Each forward pass, the collector records the values of inputs and outputs, for each layer:\n",
    "- absolute over all batches min, max (stored in `min`, `max`)\n",
    "- average over batches, per batch min, max (stored in `avg_min`, `avg_max`)\n",
    "- mean\n",
    "- std\n",
    "- shape of output tensor  \n",
    "\n",
    "All these values can be used to define the range of quantization, e.g. we can use the absolute `min`, `max` to define the range."
   ]
  },
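  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a rough sketch of how a range turns into a scale factor, the range-based linear formulas described in the docs linked above can be written as follows. The symbols $n$ (number of bits), $x_f$ (float value), $x_q$ (quantized value), $q_x$ (scale factor) and $zp_x$ (zero point) are notation assumed here for illustration, and the exact rounding and clamping details are left to the library.\n",
    "\n",
    "Symmetric mode:\n",
    "\n",
    "$$q_x = \\frac{2^{n-1} - 1}{\\max(|min|, |max|)}, \\qquad x_q = \\operatorname{round}(q_x x_f)$$\n",
    "\n",
    "Asymmetric mode:\n",
    "\n",
    "$$q_x = \\frac{2^n - 1}{max - min}, \\qquad zp_x = \\operatorname{round}(q_x min), \\qquad x_q = \\operatorname{round}(q_x x_f) - zp_x$$\n",
    "\n",
    "In both cases a smaller range gives a larger $q_x$, i.e. a finer quantization step $1 / q_x$."
   ]
  },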
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "from distiller.data_loggers import collect_quant_stats\n",
    "\n",
    "man_model = torch.load('./manual.checkpoint.pth.tar')\n",
    "distiller.utils.assign_layer_fq_names(man_model)\n",
    "collector = QuantCalibrationStatsCollector(man_model)\n",
    "\n",
    "stats_file = './acts_quantization_stats.yaml'\n",
    "\n",
    "if not os.path.isfile(stats_file):\n",
    "    def eval_for_stats(model):\n",
    "        evaluate(man_model, val_data)\n",
    "    collect_quant_stats(man_model, eval_for_stats, save_dir='.')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Prepare the Model For Quantization\n",
    "  \n",
    "We quantize the model after the training has completed.  \n",
    "Here we check the baseline model perplexity, to have an idea how good the quantization is."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 622/622 [00:21<00:00, 29.07it/s, val_loss=4.47, ppl=87.3]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "val_loss:    4.46\t|\t ppl:   86.79\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "from distiller.quantization import PostTrainLinearQuantizer, LinearQuantMode\n",
    "from copy import deepcopy\n",
    "\n",
    "# Load and evaluate the baseline model.\n",
    "man_model = torch.load('./manual.checkpoint.pth.tar')\n",
    "val_loss = evaluate(man_model, val_data)\n",
    "print('val_loss:%8.2f\\t|\\t ppl:%8.2f' % (val_loss, np.exp(val_loss)))"
   ]
  },
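  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a reminder of how the two printed numbers relate: the perplexity is the exponential of the average cross-entropy loss returned by `evaluate` (written $L$ here, a symbol introduced only for this note):\n",
    "\n",
    "$$\\text{ppl} = e^{L}$$\n",
    "\n",
    "Note that the loss above is printed with only two decimals: $e^{4.46} \\approx 86.49$, whereas the reported perplexity of $86.79$ is computed from the unrounded loss (roughly $4.4635$). Because of the exponential, an increase of $\\Delta$ in the loss multiplies the perplexity by $e^{\\Delta}$, so small changes in loss are worth tracking when comparing quantization settings."
   ]
  },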
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now we do our magic - __Preparing the model for quantization__.  \n",
    "The quantizer replaces the layers in out model with their quantized versions.  "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING: Logging before flag parsing goes to stderr.\n",
      "W1030 16:04:58.930041 140055302981376 tensor.py:435] /data2/users/gjacob/work/venvs/distiller/lib/python3.5/site-packages/torch/tensor.py:435: RuntimeWarning: Iterating over a tensor might cause the trace to be incorrect. Passing a tensor of different shape won't change the number of iterations executed (and might lead to errors or silently give incorrect results).\n",
      "  'incorrect results).', category=RuntimeWarning)\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Define the quantizer\n",
    "quantizer = PostTrainLinearQuantizer(\n",
    "    deepcopy(man_model),\n",
    "    model_activation_stats=stats_file)\n",
    "\n",
    "# Quantizer magic\n",
    "stats_before_prepare = deepcopy(quantizer.model_activation_stats)\n",
    "dummy_input = (torch.zeros(1,1).to(dtype=torch.long), man_model.init_hidden(1))\n",
    "quantizer.prepare_model(dummy_input)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Net-Aware Quantization\n",
    "\n",
    "Note that we passes a dummy input to `prepare_model`. This is required for the quantizer to be able to create a graph representation of the model, and to infer the connectivity between the modules.  \n",
    "Understanding the connectivity of the model is required to enable **\"Net-aware quantization\"**. This term (coined in [\\[1\\]](#references), section 3.2.2), means we can achieve better quantization by considering sequences of operations.  \n",
    "In the case of LSTM, we have an element-wise add operation whose output is split into 4 and fed into either Tanh or Sigmoid activations. Both of these ops saturate at relatively small input values - tanh at approximately $|4|$, and sigmoid saturates at approximately $|6|$. This means we can safely clip the output of the element-wise add operation between $[-6,6]$. `PostTrainLinearQuantizer` detects this patterm and modifies the statistics accordingly."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Stats BEFORE prepare_model:\n",
      "OrderedDict([('min', -15.612003326416016),\n",
      "             ('max', 15.450967788696289),\n",
      "             ('avg_min', -6.393781979404043),\n",
      "             ('avg_max', 5.522592915806375),\n",
      "             ('mean', -0.7798242293363614),\n",
      "             ('std', 1.3385719958875721),\n",
      "             ('shape', '(10, 6000)')])\n",
      "\n",
      "Stats AFTER to prepare_model:\n",
      "OrderedDict([('min', -6.0),\n",
      "             ('max', 6.0),\n",
      "             ('avg_min', -6.0),\n",
      "             ('avg_max', 5.522592915806375),\n",
      "             ('mean', -0.7798242293363614),\n",
      "             ('std', 1.3385719958875721),\n",
      "             ('shape', '(10, 6000)')])\n"
     ]
    }
   ],
   "source": [
    "import pprint\n",
    "pp = pprint.PrettyPrinter(indent=1)\n",
    "print('Stats BEFORE prepare_model:')\n",
    "pp.pprint(stats_before_prepare['rnn.cells.0.eltwiseadd_gate']['output'])\n",
    "\n",
    "print('\\nStats AFTER to prepare_model:')\n",
    "pp.pprint(quantizer.model_activation_stats['rnn.cells.0.eltwiseadd_gate']['output'])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Note the value for `avg_max` did not change, since it was already below the clipping value of $6.0$."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Inspecting the Quantized Model\n",
    "\n",
    "Let's see how the model has after being prepared for quantization:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "DistillerRNNModel(\n",
       "  (encoder): RangeLinearEmbeddingWrapper(\n",
       "    (wrapped_module): Embedding(33278, 1500)\n",
       "  )\n",
       "  (rnn): DistillerLSTM(1500, 1500, num_layers=2, dropout=0.65, bidirectional=False)\n",
       "  (decoder): RangeLinearQuantParamLayerWrapper(\n",
       "    weights_quant_settings=(num_bits=8 ; quant_mode=SYMMETRIC ; clip_mode=NONE ; clip_n_stds=None ; clip_half_range=False ; per_channel=False)\n",
       "    output_quant_settings=(num_bits=8 ; quant_mode=SYMMETRIC ; clip_mode=NONE ; clip_n_stds=None ; clip_half_range=False ; per_channel=False)\n",
       "    accum_quant_settings=(num_bits=32 ; quant_mode=SYMMETRIC ; clip_mode=NONE ; clip_n_stds=None ; clip_half_range=False ; per_channel=False)\n",
       "    requires_quantized_inputs=True\n",
       "      inputs_quant_auto_fallback=True, forced_quant_settings_for_inputs=None\n",
       "    scale_approx_mult_bits=None\n",
       "    preset_activation_stats=True\n",
       "      output_scale=3.656110, output_zero_point=0.000000\n",
       "    weights_scale=123.293175, weights_zero_point=0.000000\n",
       "    (wrapped_module): Linear(in_features=1500, out_features=33278, bias=True)\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "quantizer.model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Note how `encoder` and `decoder` have been replaced with wrapper layers (for the relevant module type), which handle the quantization. The same holds for the internal layers of the `DistillerLSTM` module, which we don't print for brevity sake. To \"peek\" inside the `DistillerLSTM` module, we need to access it directly. As an example, let's take a look at a couple of the internal layers:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "RangeLinearQuantParamLayerWrapper(\n",
      "  weights_quant_settings=(num_bits=8 ; quant_mode=SYMMETRIC ; clip_mode=NONE ; clip_n_stds=None ; clip_half_range=False ; per_channel=False)\n",
      "  output_quant_settings=(num_bits=8 ; quant_mode=SYMMETRIC ; clip_mode=NONE ; clip_n_stds=None ; clip_half_range=False ; per_channel=False)\n",
      "  accum_quant_settings=(num_bits=32 ; quant_mode=SYMMETRIC ; clip_mode=NONE ; clip_n_stds=None ; clip_half_range=False ; per_channel=False)\n",
      "  requires_quantized_inputs=True\n",
      "    inputs_quant_auto_fallback=True, forced_quant_settings_for_inputs=None\n",
      "  scale_approx_mult_bits=None\n",
      "  preset_activation_stats=True\n",
      "    output_scale=17.711266, output_zero_point=0.000000\n",
      "  weights_scale=163.081299, weights_zero_point=0.000000\n",
      "  (wrapped_module): Linear(in_features=1500, out_features=6000, bias=True)\n",
      ")\n",
      "RangeLinearQuantEltwiseAddWrapper(\n",
      "  output_quant_settings=(num_bits=8 ; quant_mode=SYMMETRIC ; clip_mode=NONE ; clip_n_stds=None ; clip_half_range=False ; per_channel=False)\n",
      "  accum_quant_settings=(num_bits=32 ; quant_mode=SYMMETRIC ; clip_mode=NONE ; clip_n_stds=None ; clip_half_range=False ; per_channel=False)\n",
      "  requires_quantized_inputs=True\n",
      "    inputs_quant_auto_fallback=True, forced_quant_settings_for_inputs=None\n",
      "  scale_approx_mult_bits=None\n",
      "  preset_activation_stats=True\n",
      "    output_scale=21.166668, output_zero_point=0.000000\n",
      "  (wrapped_module): EltwiseAdd()\n",
      ")\n"
     ]
    }
   ],
   "source": [
    "print(quantizer.model.rnn.cells[0].fc_gate_x)\n",
    "print(quantizer.model.rnn.cells[0].eltwiseadd_gate)"
   ]
  },
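  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `output_scale` printed for `eltwiseadd_gate` can be checked by hand. Assuming the symmetric scale formula from the Distiller docs, $q_x = (2^{n-1} - 1) / \\max(|min|, |max|)$ (the symbol names are chosen here for illustration), with $n = 8$ bits and the range $[-6, 6]$ set by net-aware quantization:\n",
    "\n",
    "$$q_x = \\frac{2^{7} - 1}{6} = \\frac{127}{6} \\approx 21.1667$$\n",
    "\n",
    "This agrees with the printed `output_scale=21.166668`. Without the clipping, the absolute range of about $15.612$ would have given a scale of only $127 / 15.612 \\approx 8.13$, i.e. a much coarser quantization step."
   ]
  },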
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Running the Quantized Model\n",
    "\n",
    "### Try 1: Initial settings - simple symmetric quantization\n",
    "\n",
    "Finally, let's go ahead and evaluate the quantized model:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 622/622 [03:11<00:00,  3.46it/s, val_loss=4.65, ppl=105]  "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "val_loss:    4.65\t|\t ppl:  104.20\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "val_loss = evaluate(quantizer.model.to(device), val_data)\n",
    "print('val_loss:%8.2f\\t|\\t ppl:%8.2f' % (val_loss, np.exp(val_loss)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As we can see here, the perplexity has increased much - meaning our quantization has damaged the accuracy of our model."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Try 2: Assymetric, per-channel\n",
    "\n",
    "Let's try quantizing each channel separately, and making the range of the quantization asymmetric."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "DistillerRNNModel(\n",
       "  (encoder): RangeLinearEmbeddingWrapper(\n",
       "    (wrapped_module): Embedding(33278, 1500)\n",
       "  )\n",
       "  (rnn): DistillerLSTM(1500, 1500, num_layers=2, dropout=0.65, bidirectional=False)\n",
       "  (decoder): RangeLinearQuantParamLayerWrapper(\n",
       "    weights_quant_settings=(num_bits=8 ; quant_mode=ASYMMETRIC_SIGNED ; clip_mode=NONE ; clip_n_stds=None ; clip_half_range=False ; per_channel=True)\n",
       "    output_quant_settings=(num_bits=8 ; quant_mode=ASYMMETRIC_SIGNED ; clip_mode=NONE ; clip_n_stds=None ; clip_half_range=False ; per_channel=False)\n",
       "    accum_quant_settings=(num_bits=32 ; quant_mode=SYMMETRIC ; clip_mode=NONE ; clip_n_stds=None ; clip_half_range=False ; per_channel=False)\n",
       "    requires_quantized_inputs=True\n",
       "      inputs_quant_auto_fallback=True, forced_quant_settings_for_inputs=None\n",
       "    scale_approx_mult_bits=None\n",
       "    preset_activation_stats=True\n",
       "      output_scale=5.024112, output_zero_point=48.000000\n",
       "    weights_scale=PerCh, weights_zero_point=PerCh\n",
       "    (wrapped_module): Linear(in_features=1500, out_features=33278, bias=True)\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "quantizer = PostTrainLinearQuantizer(\n",
    "    deepcopy(man_model),\n",
    "    model_activation_stats=stats_file,\n",
    "    mode=LinearQuantMode.ASYMMETRIC_SIGNED,\n",
    "    per_channel_wts=True\n",
    ")\n",
    "quantizer.prepare_model(dummy_input)\n",
    "quantizer.model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 622/622 [03:09<00:00,  3.54it/s, val_loss=4.62, ppl=101]  "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "val_loss:    4.61\t|\t ppl:  100.45\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "val_loss = evaluate(quantizer.model.to(device), val_data)\n",
    "print('val_loss:%8.2f\\t|\\t ppl:%8.2f' % (val_loss, np.exp(val_loss)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A tiny bit better, but still no good."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Try 3: Mixed FP16 and INT8\n",
    "\n",
    "Let us try the half precision (aka FP16) version of the model:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 622/622 [00:22<00:00, 27.95it/s, val_loss=4.47, ppl=87.3]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "val_loss: 4.463559\t|\t ppl:   86.80\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "model_fp16 = deepcopy(man_model).half()\n",
    "val_loss = evaluate(model_fp16, val_data)\n",
    "print('val_loss: %8.6f\\t|\\t ppl:%8.2f' % (val_loss, np.exp(val_loss)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The result is very close to our original model! That means that the roundoff when quantizing linearly to 8-bit integers is what hurts our accuracy.\n",
    "\n",
    "Luckily, `PostTrainLinearQuantizer` supports quantizing some/all layers to FP16 using the `fp16` parameter. In light of what we just saw, and as stated in [\\[2\\]](#References), let's try keeping element-wise operations at FP16, and quantize everything else to 8-bit using the same settings as in try 2."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "W1030 16:35:44.686105 140055302981376 range_linear.py:1126] /data2/users/gjacob/work/distiller/distiller/quantization/range_linear.py:1126: DeprecationWarning: Argument 'fp16' is deprecated. Please use 'fpq_module'(=16/32/64) argument.\n",
      "  DeprecationWarning)\n",
      "\n",
      "W1030 16:35:44.896657 140055302981376 range_linear.py:1093] /data2/users/gjacob/work/distiller/distiller/quantization/range_linear.py:1093: DeprecationWarning: Argument 'fp16' is deprecated. Please use 'fpq_module'(=16/32/64) argument.\n",
      "  DeprecationWarning)\n",
      "\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "DistillerRNNModel(\n",
       "  (encoder): FP16Wrapper(\n",
       "    (wrapped_module): Embedding(33278, 1500)\n",
       "  )\n",
       "  (rnn): DistillerLSTM(1500, 1500, num_layers=2, dropout=0.65, bidirectional=False)\n",
       "  (decoder): FPWrapper(\n",
       "    (wrapped_module): Linear(in_features=1500, out_features=33278, bias=True)\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "overrides_yaml = \"\"\"\n",
    ".*eltwise.*:\n",
    "    fp16: true\n",
    "encoder:\n",
    "    fp16: true\n",
    "decoder:\n",
    "    fp16: true\n",
    "\"\"\"\n",
    "overrides = distiller.utils.yaml_ordered_load(overrides_yaml)\n",
    "quantizer = PostTrainLinearQuantizer(\n",
    "    deepcopy(man_model),\n",
    "    model_activation_stats=stats_file,\n",
    "    mode=LinearQuantMode.ASYMMETRIC_SIGNED,\n",
    "    overrides=overrides,\n",
    "    per_channel_wts=True\n",
    ")\n",
    "quantizer.prepare_model(dummy_input)\n",
    "quantizer.model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 622/622 [01:45<00:00,  6.27it/s, val_loss=4.47, ppl=87.3] "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "val_loss:4.463292\t|\t ppl:   86.77\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "val_loss = evaluate(quantizer.model.to(device), val_data)\n",
    "print('val_loss:%8.6f\\t|\\t ppl:%8.2f' % (val_loss, np.exp(val_loss)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The accuracy is still holding up very well, even though we quantized the inner linear layers!  "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Try 4: Clipping Activations\n",
    "\n",
    "Now, lets try to choose different boundaries for `min`, `max`.  \n",
    "Instead of using absolute ones, we take the average of all batches (`avg_min`, `avg_max`), which is an indication of where usually most of the boundaries lie. This is done by specifying the `clip_acts` parameter to `ClipMode.AVG` or `\"AVG\"` in the quantizer ctor:"
   ]
  },
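  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To get a feel for what switching from the absolute to the averaged boundaries buys us, we can plug in the stats printed earlier for `rnn.cells.0.eltwiseadd_gate` (before `prepare_model`). This is only an illustration: it assumes 8-bit asymmetric quantization with a step size of $(max - min) / (2^8 - 1)$, and it ignores the net-aware clipping that the quantizer applies to this particular layer.\n",
    "\n",
    "$$\\text{step}_{abs} = \\frac{15.451 - (-15.612)}{255} \\approx 0.122, \\qquad \\text{step}_{avg} = \\frac{5.523 - (-6.394)}{255} \\approx 0.047$$\n",
    "\n",
    "The averaged range yields a step roughly $2.6\\times$ finer, at the cost of saturating the values that fall outside the averaged range."
   ]
  },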
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 622/622 [03:04<00:00,  3.61it/s, val_loss=4.49, ppl=89.5] "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "val_loss:4.488176\t|\t ppl:   88.96\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "overrides_yaml = \"\"\"\n",
    "encoder:\n",
    "    fp16: true\n",
    "decoder:\n",
    "    fp16: true\n",
    "\"\"\"\n",
    "overrides = distiller.utils.yaml_ordered_load(overrides_yaml)\n",
    "quantizer = PostTrainLinearQuantizer(\n",
    "    deepcopy(man_model),\n",
    "    model_activation_stats=stats_file,\n",
    "    mode=LinearQuantMode.ASYMMETRIC_SIGNED,\n",
    "    overrides=overrides,\n",
    "    per_channel_wts=True,\n",
    "    clip_acts=\"AVG\"\n",
    ")\n",
    "quantizer.prepare_model(dummy_input)\n",
    "val_loss = evaluate(quantizer.model.to(device), val_data)\n",
    "print('val_loss:%8.6f\\t|\\t ppl:%8.2f' % (val_loss, np.exp(val_loss)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Great! Even though we quantized all of the layers except the embedding and the decoder - we got almost no accuracy penalty. Lets try quantizing them as well:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 622/622 [03:05<00:00,  3.54it/s, val_loss=4.49, ppl=89.4] "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "val_loss:4.486969\t|\t ppl:   88.85\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "quantizer = PostTrainLinearQuantizer(\n",
    "    deepcopy(man_model),\n",
    "    model_activation_stats=stats_file,\n",
    "    mode=LinearQuantMode.ASYMMETRIC_SIGNED,\n",
    "    per_channel_wts=True,\n",
    "    clip_acts=\"AVG\"\n",
    ")\n",
    "quantizer.prepare_model(dummy_input)\n",
    "val_loss = evaluate(quantizer.model.to(device), val_data)\n",
    "print('val_loss:%8.6f\\t|\\t ppl:%8.2f' % (val_loss, np.exp(val_loss)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "DistillerRNNModel(\n",
       "  (encoder): RangeLinearEmbeddingWrapper(\n",
       "    (wrapped_module): Embedding(33278, 1500)\n",
       "  )\n",
       "  (rnn): DistillerLSTM(1500, 1500, num_layers=2, dropout=0.65, bidirectional=False)\n",
       "  (decoder): RangeLinearQuantParamLayerWrapper(\n",
       "    weights_quant_settings=(num_bits=8 ; quant_mode=ASYMMETRIC_SIGNED ; clip_mode=NONE ; clip_n_stds=None ; clip_half_range=False ; per_channel=True)\n",
       "    output_quant_settings=(num_bits=8 ; quant_mode=ASYMMETRIC_SIGNED ; clip_mode=AVG ; clip_n_stds=None ; clip_half_range=False ; per_channel=False)\n",
       "    accum_quant_settings=(num_bits=32 ; quant_mode=SYMMETRIC ; clip_mode=NONE ; clip_n_stds=None ; clip_half_range=False ; per_channel=False)\n",
       "    requires_quantized_inputs=True\n",
       "      inputs_quant_auto_fallback=True, forced_quant_settings_for_inputs=None\n",
       "    scale_approx_mult_bits=None\n",
       "    preset_activation_stats=True\n",
       "      output_scale=9.939270, output_zero_point=56.000000\n",
       "    weights_scale=PerCh, weights_zero_point=PerCh\n",
       "    (wrapped_module): Linear(in_features=1500, out_features=33278, bias=True)\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "quantizer.model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Here we see that sometimes quantizing with the right boundaries gives better results than actually using floating point operations (even though they are half precision). "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Conclusion\n",
    "\n",
    "Choosing the right boundaries for quantization  was crucial for achieving almost no degradation in accrucay of LSTM.  \n",
    "  \n",
    "Here we showed how to use the Distiller quantization API to quantize an RNN model, by converting the PyTorch implementation into a modular one and then quantizing each layer separately."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## References\n",
    "\n",
    "1. **Jongsoo Park, Maxim Naumov, Protonu Basu, Summer Deng, Aravind Kalaiah, Daya Khudia, James Law, Parth Malani, Andrey Malevich, Satish Nadathur, Juan Miguel Pino, Martin Schatz, Alexander Sidorov, Viswanath Sivakumar, Andrew Tulloch, Xiaodong Wang, Yiming Wu, Hector Yuen, Utku Diril, Dmytro Dzhulgakov, Kim Hazelwood, Bill Jia, Yangqing Jia, Lin Qiao, Vijay Rao, Nadav Rotem, Sungjoo Yoo, Mikhail Smelyanskiy**. Deep Learning Inference in Facebook Data Centers: Characterization, Performance Optimizations and Hardware Implications. [arxiv:1811.09886](https://arxiv.org/abs/1811.09886)\n",
    "\n",
    "2. **Qinyao He, He Wen, Shuchang Zhou, Yuxin Wu, Cong Yao, Xinyu Zhou, Yuheng Zou**. Effective Quantization Methods for Recurrent Neural Networks. [arxiv:1611.10176](https://arxiv.org/abs/1611.10176)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
