{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch import nn, optim\n",
    "import RVQE"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.set_num_threads(2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Our goal is to create a RNN or LSTM with roughly 1965 parameters, and compare it in the dna long sequence task implemented within RVQE."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset_t = lambda length: RVQE.datasets.all_datasets[\"dna\"](0, num_shards=0, batch_size=16, sentence_length=length)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def count_parameters(model):\n",
    "    return sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
    "def to_one_hot(labels, num_classes=2**3):\n",
    "    return torch.eye(num_classes)[labels]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "SEEDS = [9120, 2783, 2057, 6549, 3201, 7063, 5243, 3102, 5303, 5819, 3693, 4884, 2231, 5514, 8850, 6861, 3106, 2378, 8697, 1821, 9480, 8483, 1633, 9678, 6596, 4509, 8618, 9765, 6346, 2969];\n",
    "LENGTHS = [5, 10, 20, 50, 100, 200, 500, 1000];"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# RNN"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "class SimpleRNN(nn.Module):\n",
    "    \"\"\"\n",
    "        This is a very simplistic RNN setup. We found a single layer performs\n",
    "        much better than two layers with a smaller hidden size.\n",
    "        Without doubt one can improve the performance of this model.\n",
    "        Yet we didn't optimize the QRNN setup for the task at hand either.\n",
    "    \"\"\"\n",
    "    def __init__(self, io_size=2**3, hidden_size=80):\n",
    "        super().__init__()\n",
    "        \n",
    "        self.rnn = nn.RNN(input_size=io_size, hidden_size=hidden_size, num_layers=2, batch_first=True)\n",
    "        self.lin = nn.Linear(hidden_size, io_size)\n",
    "        \n",
    "    def reset(self):\n",
    "        self.lin.reset_parameters()\n",
    "        for name, param in self.rnn.named_parameters():\n",
    "            # give an orthogonal start\n",
    "            if \"weight_hh\" in name:\n",
    "                torch.nn.init.orthogonal_(param.data)\n",
    "            elif \"bias\" in name:\n",
    "                param.data.fill_(0)\n",
    "            elif \"weight_ih\" in name:\n",
    "                torch.nn.init.xavier_uniform_(param.data)\n",
    "            else:\n",
    "                raise Exception(f\"cannot initialize {name}\")\n",
    "        \n",
    "    @property\n",
    "    def num_parameters(self):\n",
    "        return count_parameters(self.rnn) + count_parameters(self.lin)\n",
    "        \n",
    "    def forward(self, sentence):\n",
    "        rnn_out, _ = self.rnn(sentence)\n",
    "        return self.lin(rnn_out)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "20808"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "SimpleRNN().num_parameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "results = {}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "creating RNN with 20808 parameters\n",
      "creating RNN with 20808 parameters\n",
      "creating RNN with 20808 parameters\n",
      "creating RNN with 20808 parameters\n",
      "000500 1.34e+00\n",
      "001000 1.37e+00\n",
      "001500 1.53e+00\n",
      "002000 1.36e+00\n",
      "002500 1.55e+00\n",
      "003000 1.31e+00\n",
      "003500 1.33e+00\n",
      "004000 1.42e+00\n",
      "004500 1.57e+00\n",
      "005000 1.39e+00\n",
      "005500 1.59e+00\n",
      "006000 1.33e+00\n",
      "006500 1.32e+00\n",
      "007000 1.37e+00\n",
      "007500 1.50e+00\n",
      "008000 1.44e+00\n",
      "008500 1.39e+00\n",
      "009000 1.58e+00\n",
      "009500 1.28e+00\n",
      "010000 1.39e+00\n",
      "010500 1.41e+00\n",
      "011000 1.23e+00\n",
      "011500 1.22e+00\n",
      "012000 1.53e+00\n",
      "012500 1.41e+00\n",
      "013000 1.32e+00\n",
      "013500 1.39e+00\n",
      "014000 1.43e+00\n",
      "014500 1.26e+00\n",
      "015000 1.49e+00\n",
      "015500 1.45e+00\n",
      "length 50 did not converge after step steps.\n",
      "000500 1.43e+00\n",
      "001000 1.43e+00\n",
      "001500 1.37e+00\n",
      "002000 1.32e+00\n",
      "002500 1.27e+00\n",
      "003000 1.22e+00\n",
      "003500 1.43e+00\n",
      "004000 1.16e+00\n",
      "004500 1.34e+00\n",
      "005000 1.28e+00\n",
      "005500 1.14e+00\n",
      "006000 1.47e+00\n",
      "006500 1.27e+00\n",
      "007000 1.35e+00\n",
      "007500 1.28e+00\n",
      "008000 1.28e+00\n",
      "008500 1.39e+00\n",
      "009000 1.36e+00\n",
      "009500 1.32e+00\n",
      "010000 1.89e+00\n",
      "010500 1.49e+00\n",
      "011000 1.41e+00\n",
      "011500 1.35e+00\n",
      "012000 1.34e+00\n",
      "012500 1.33e+00\n",
      "013000 1.42e+00\n",
      "013500 1.39e+00\n",
      "014000 1.41e+00\n",
      "014500 1.55e+00\n",
      "015000 1.52e+00\n",
      "015500 1.32e+00\n",
      "length 50 did not converge after step steps.\n",
      "000500 1.49e+00\n",
      "001000 1.42e+00\n",
      "001500 1.44e+00\n",
      "002000 1.50e+00\n",
      "002500 1.22e+00\n",
      "003000 1.18e+00\n",
      "003500 1.25e+00\n",
      "004000 1.12e+00\n",
      "004500 1.25e+00\n",
      "005000 1.15e+00\n",
      "005500 1.21e+00\n",
      "006000 1.26e+00\n",
      "006500 1.22e+00\n",
      "007000 1.21e+00\n",
      "007500 1.16e+00\n",
      "008000 1.14e+00\n",
      "008500 1.29e+00\n",
      "009000 1.21e+00\n",
      "009500 1.26e+00\n",
      "010000 1.14e+00\n",
      "010500 1.19e+00\n",
      "011000 1.22e+00\n",
      "011500 1.23e+00\n",
      "012000 1.14e+00\n",
      "012500 1.30e+00\n",
      "013000 1.23e+00\n",
      "013500 1.12e+00\n",
      "014000 1.25e+00\n",
      "014500 1.22e+00\n",
      "015000 1.22e+00\n",
      "015500 1.24e+00\n",
      "length 50 did not converge after step steps.\n",
      "000500 1.85e+00\n",
      "001000 1.50e+00\n",
      "001500 1.57e+00\n",
      "002000 1.60e+00\n",
      "002500 1.37e+00\n",
      "003000 1.47e+00\n",
      "003500 1.36e+00\n",
      "004000 1.51e+00\n",
      "004500 1.61e+00\n",
      "005000 1.37e+00\n",
      "005500 1.34e+00\n",
      "006000 1.46e+00\n",
      "006500 1.31e+00\n",
      "007000 1.34e+00\n",
      "007500 1.24e+00\n",
      "008000 1.43e+00\n",
      "008500 1.56e+00\n",
      "009000 1.29e+00\n",
      "009500 1.28e+00\n",
      "010000 1.35e+00\n",
      "010500 1.23e+00\n",
      "011000 1.31e+00\n",
      "011500 1.30e+00\n",
      "012000 1.30e+00\n",
      "012500 1.39e+00\n",
      "013000 1.44e+00\n",
      "013500 1.50e+00\n",
      "014000 1.47e+00\n",
      "014500 1.43e+00\n",
      "015000 1.49e+00\n",
      "015500 1.46e+00\n",
      "length 50 did not converge after step steps.\n",
      "000500 1.53e+00\n",
      "001000 1.26e+00\n",
      "001500 1.47e+00\n",
      "002000 1.27e+00\n",
      "002500 1.38e+00\n",
      "003000 1.38e+00\n",
      "003500 1.35e+00\n",
      "004000 1.24e+00\n",
      "004500 1.39e+00\n",
      "005000 1.37e+00\n",
      "005500 1.71e+00\n",
      "006000 1.44e+00\n",
      "006500 1.27e+00\n",
      "007000 1.22e+00\n",
      "007500 1.35e+00\n",
      "008000 1.92e+00\n",
      "008500 1.65e+00\n",
      "009000 1.33e+00\n",
      "009500 1.51e+00\n",
      "010000 1.31e+00\n",
      "010500 1.51e+00\n",
      "011000 1.44e+00\n",
      "011500 1.41e+00\n",
      "012000 1.55e+00\n",
      "012500 1.33e+00\n",
      "013000 1.39e+00\n",
      "013500 1.42e+00\n",
      "014000 1.46e+00\n",
      "014500 1.38e+00\n",
      "015000 1.29e+00\n",
      "015500 1.45e+00\n",
      "length 50 did not converge after step steps.\n",
      "000500 1.68e+00\n",
      "001000 1.46e+00\n"
     ]
    }
   ],
   "source": [
    "for length in LENGTHS:\n",
    "    \n",
    "    dataset = dataset_t(length)\n",
    "    print(f\"creating RNN with {SimpleRNN().num_parameters} parameters\")\n",
    "    \n",
    "    criterion = nn.CrossEntropyLoss()\n",
    "    \n",
    "    results[length] = results[length] if length in results else []\n",
    "    \n",
    "    for seed in SEEDS:\n",
    "        if seed in [ s for s, _ in results[length] ]:\n",
    "            continue\n",
    "            \n",
    "        torch.manual_seed(seed)\n",
    "        model = SimpleRNN()\n",
    "        model.reset()\n",
    "        optimizer = optim.Adam(model.parameters(), lr=0.01)  # this has been found to converge fastest\n",
    "        \n",
    "        for step in range(1, 16 * 1000):  # cap amounts to the same number of samples seen as for qrnn\n",
    "            sentence, target = dataset.next_batch(0, RVQE.data.TrainingStage.TRAIN)\n",
    "            \n",
    "            # transform sentence to one-hot as in the qrnn case\n",
    "            sentence = to_one_hot(RVQE.data.targets_for_loss(sentence))            \n",
    "            \n",
    "            optimizer.zero_grad()\n",
    "            out = model(sentence.float())\n",
    "            \n",
    "            # unlike the qrnn case, we use the entire output as loss\n",
    "            # this gives the rnn an advantage!\n",
    "            out = out.transpose(1, 2)\n",
    "            target = RVQE.data.targets_for_loss(target)\n",
    "            loss = criterion(out, target)\n",
    "            \n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "            \n",
    "            if loss < 0.0005:\n",
    "                results[length].append([seed, step])\n",
    "                print(f\"length {length} converged after {step} steps.\")\n",
    "                break\n",
    "            \n",
    "            if step % 500 == 0:\n",
    "                print(f\"{step:06d} {loss:.2e}\")\n",
    "                \n",
    "        else:\n",
    "            print(f\"length {length} did not converge after step steps.\")\n",
    "            results[length].append([seed, step])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "metadata": {},
   "outputs": [],
   "source": [
    "pd.DataFrame([ [key, seed, step, .0] for key in results for seed, step in results[key] ], columns=[\"sentence_length\", \"seed\", \"hparams/epoch\", \"hparams/validate_best\"], index=None).to_csv(\"~/long-rnn.csv\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "dict_items([(5, [[9120, 48], [2783, 45], [2057, 43], [6549, 57], [3201, 53], [7063, 65], [5243, 44], [3102, 41], [5303, 53], [5819, 50], [3693, 47], [4884, 47], [2231, 49], [5514, 46], [8850, 58], [6861, 42], [3106, 40], [2378, 68], [8697, 44], [1821, 46], [9480, 47], [8483, 53], [1633, 53], [9678, 49], [6596, 43], [4509, 43], [8618, 46], [9765, 46], [6346, 44], [2969, 49]]), (10, [[9120, 386], [2783, 276], [2057, 285], [6549, 304], [3201, 387], [7063, 432], [5243, 216], [3102, 352], [5303, 298], [5819, 415], [3693, 262], [4884, 317], [2231, 386], [5514, 342], [8850, 436], [6861, 424], [3106, 294], [2378, 285], [8697, 331], [1821, 348], [9480, 299], [8483, 419], [1633, 374], [9678, 401], [6596, 412], [4509, 422], [8618, 385], [9765, 277], [6346, 602], [2969, 302]]), (20, [[9120, 15999], [2783, 15999], [2057, 15999], [6549, 15999], [3201, 15999], [7063, 15999], [5243, 15999], [3102, 15999], [5303, 15999], [5819, 15999], [3693, 15999], [4884, 15999], [2231, 15999], [5514, 15999], [8850, 15999], [6861, 15999], [3106, 15999], [2378, 15999], [8697, 15999], [1821, 15999], [9480, 15999], [8483, 15999], [1633, 15999], [9678, 15999], [6596, 15999], [4509, 15999], [8618, 15999], [9765, 15999], [6346, 15999], [2969, 15999]]), (50, [[9120, 15999], [2783, 15999], [2057, 15999], [6549, 15999], [3201, 15999], [7063, 15999], [5243, 15999], [3102, 15999], [5303, 15999], [5819, 15999], [3693, 15999], [4884, 15999], [2231, 15999], [5514, 15999], [8850, 15999], [6861, 15999], [3106, 15999], [2378, 15999], [8697, 15999], [1821, 15999], [9480, 15999], [8483, 15999], [1633, 15999], [9678, 15999], [6596, 15999], [4509, 15999], [8618, 15999], [9765, 15999], [6346, 15999], [2969, 15999]]), (100, [[9120, 15999], [2783, 15999], [2057, 15999], [6549, 15999], [3201, 15999], [7063, 15999], [5243, 15999], [3102, 15999], [5303, 15999], [5819, 15999], [3693, 15999], [4884, 15999]])])"
      ]
     },
     "execution_count": 30,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "results.items() "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:rvqe]",
   "language": "python",
   "name": "conda-env-rvqe-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
