{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch import nn, optim\n",
    "import RVQE"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Limit PyTorch to two CPU threads for intra-op parallelism.\n",
    "torch.set_num_threads(2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Our goal is to create an RNN or LSTM with roughly 1965 parameters and compare it on the DNA long-sequence task implemented in RVQE."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def dataset_t(length):\n",
    "    \"\"\"Instantiate RVQE's `dna` dataset with batch size 16 and the given sentence length.\"\"\"\n",
    "    dna_dataset_cls = RVQE.datasets.all_datasets[\"dna\"]\n",
    "    return dna_dataset_cls(0, num_shards=0, batch_size=16, sentence_length=length)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def count_parameters(model):\n",
    "    \"\"\"Return the total number of scalar entries across `model`'s trainable parameters.\"\"\"\n",
    "    total = 0\n",
    "    for tensor in model.parameters():\n",
    "        # frozen parameters (requires_grad == False) are not counted\n",
    "        if tensor.requires_grad:\n",
    "            total += tensor.numel()\n",
    "    return total\n",
    "def to_one_hot(labels, num_classes=2**3):\n",
    "    \"\"\"Map integer class labels to one-hot rows by indexing an identity matrix.\"\"\"\n",
    "    identity = torch.eye(num_classes)\n",
    "    return identity[labels]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Random seeds; the training loop performs one run per seed and sentence length.\n",
    "SEEDS = [\n",
    "    9120, 2783, 2057, 6549, 3201, 7063, 5243, 3102, 5303, 5819,\n",
    "    3693, 4884, 2231, 5514, 8850, 6861, 3106, 2378, 8697, 1821,\n",
    "    9480, 8483, 1633, 9678, 6596, 4509, 8618, 9765, 6346, 2969,\n",
    "]\n",
    "# Sentence lengths to sweep over.\n",
    "LENGTHS = [5, 10, 20, 50, 100, 200, 500, 1000]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# LSTM"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "class SimpleLSTM(nn.Module):\n",
    "    \"\"\"\n",
    "        This is a very simplistic LSTM setup: a (stacked) LSTM followed by a\n",
    "        linear read-out that maps every hidden state back to `io_size` logits.\n",
    "\n",
    "        Args:\n",
    "            io_size: width of each input vector and of each output logit vector.\n",
    "            hidden_size: width of the LSTM hidden state.\n",
    "            num_layers: number of stacked LSTM layers.\n",
    "\n",
    "        NOTE(review): the original note on this class stated that a single\n",
    "        layer performs much better than two layers with a smaller hidden size,\n",
    "        yet the model has always been built with two layers. The default of\n",
    "        `num_layers=2` keeps that behaviour; pass `num_layers=1` for the\n",
    "        single-layer variant.\n",
    "    \"\"\"\n",
    "    def __init__(self, io_size=2**3, hidden_size=40, num_layers=2):\n",
    "        super().__init__()\n",
    "\n",
    "        self.rnn = nn.LSTM(input_size=io_size, hidden_size=hidden_size, num_layers=num_layers, batch_first=True)\n",
    "        self.lin = nn.Linear(hidden_size, io_size)\n",
    "\n",
    "    def reset(self):\n",
    "        \"\"\"\n",
    "            Re-initialise all weights in place: default init for the linear\n",
    "            layer, orthogonal hidden-to-hidden blocks, Xavier-uniform\n",
    "            input-to-hidden weights and zero biases for the LSTM.\n",
    "\n",
    "            Raises:\n",
    "                ValueError: if the LSTM exposes a parameter that is none of\n",
    "                    `weight_hh*`, `weight_ih*` or `bias*`.\n",
    "        \"\"\"\n",
    "        self.lin.reset_parameters()\n",
    "        for name, param in self.rnn.named_parameters():\n",
    "            if \"weight_hh\" in name:\n",
    "                # the gate matrices are stacked along dim 0, so give each of\n",
    "                # the four (hidden x hidden) blocks its own orthogonal start\n",
    "                hidden = param.data.shape[1]\n",
    "                for gate in range(4):\n",
    "                    nn.init.orthogonal_(param.data[hidden*gate : hidden*(gate+1), :])\n",
    "            elif \"bias\" in name:\n",
    "                param.data.fill_(0)\n",
    "            elif \"weight_ih\" in name:\n",
    "                nn.init.xavier_uniform_(param.data)\n",
    "            else:\n",
    "                raise ValueError(f\"cannot initialize {name}\")\n",
    "\n",
    "    @property\n",
    "    def num_parameters(self):\n",
    "        \"\"\"Number of trainable scalars in the LSTM and the linear read-out.\"\"\"\n",
    "        return sum(p.numel() for p in self.parameters() if p.requires_grad)\n",
    "\n",
    "    def forward(self, sentence):\n",
    "        \"\"\"Run the LSTM over `sentence` (batch first) and apply the read-out to every step.\"\"\"\n",
    "        rnn_out, _ = self.rnn(sentence)\n",
    "        return self.lin(rnn_out)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "21448"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Show the number of trainable parameters of the default model.\n",
    "SimpleLSTM().num_parameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Maps sentence length -> list of [seed, step] entries, one per finished training run.\n",
    "results = {}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "created LSTM with 21448 parameters\n",
      "created LSTM with 21448 parameters\n",
      "created LSTM with 21448 parameters\n",
      "created LSTM with 21448 parameters\n",
      "created LSTM with 21448 parameters\n",
      "created LSTM with 21448 parameters\n",
      "created LSTM with 21448 parameters\n",
      "created LSTM with 21448 parameters\n",
      "000500 5.76e-01\n",
      "001000 5.88e-01\n",
      "001500 2.75e-01\n",
      "002000 3.17e-01\n",
      "length 1000 converged after 2485 steps.\n",
      "000500 1.34e-03\n",
      "length 1000 converged after 664 steps.\n",
      "000500 2.12e-02\n",
      "length 1000 converged after 728 steps.\n",
      "000500 7.04e-01\n",
      "001000 1.52e-03\n",
      "length 1000 converged after 1176 steps.\n",
      "000500 1.23e+00\n",
      "001000 1.41e+00\n",
      "001500 1.25e+00\n",
      "002000 1.37e+00\n",
      "002500 1.31e+00\n",
      "003000 1.25e+00\n",
      "003500 1.21e+00\n",
      "004000 1.22e+00\n"
     ]
    }
   ],
   "source": [
    "# Train one model per (sentence length, seed) and record how many steps it took to converge.\n",
    "MAX_STEPS = 16 * 1000  # cap amounts to the same number of samples seen as for qrnn\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "\n",
    "for length in LENGTHS:\n",
    "\n",
    "    dataset = dataset_t(length)\n",
    "    print(f\"created LSTM with {SimpleLSTM().num_parameters} parameters\")\n",
    "\n",
    "    # keep earlier entries so that re-running the cell only trains the missing seeds\n",
    "    runs = results.setdefault(length, [])\n",
    "\n",
    "    for seed in SEEDS[:5]:\n",
    "        if any(done_seed == seed for done_seed, _ in runs):\n",
    "            continue\n",
    "\n",
    "        torch.manual_seed(seed)\n",
    "        model = SimpleLSTM()\n",
    "        model.reset()\n",
    "        optimizer = optim.Adam(model.parameters(), lr=0.05)   # this has been found to converge fastest\n",
    "\n",
    "        # NOTE(review): range(1, MAX_STEPS) performs MAX_STEPS - 1 updates; kept as is so that\n",
    "        # non-converged runs are still recorded with the same step count as earlier results.\n",
    "        for step in range(1, MAX_STEPS):\n",
    "            sentence, target = dataset.next_batch(0, RVQE.data.TrainingStage.TRAIN)\n",
    "\n",
    "            # transform sentence to one-hot as in the qrnn case\n",
    "            sentence = to_one_hot(RVQE.data.targets_for_loss(sentence))\n",
    "\n",
    "            optimizer.zero_grad()\n",
    "            out = model(sentence.float())\n",
    "\n",
    "            # unlike the qrnn case, we use the entire output as loss\n",
    "            # this gives the rnn an advantage!\n",
    "            # CrossEntropyLoss wants the class dimension at index 1, hence the transpose\n",
    "            out = out.transpose(1, 2)\n",
    "            target = RVQE.data.targets_for_loss(target)\n",
    "            loss = criterion(out, target)\n",
    "\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "\n",
    "            if torch.isnan(loss):\n",
    "                # a diverged run is recorded with the full step budget\n",
    "                print(\"nan\")\n",
    "                runs.append([seed, MAX_STEPS])\n",
    "                break\n",
    "\n",
    "            if loss < 0.0005:\n",
    "                runs.append([seed, step])\n",
    "                print(f\"length {length} converged after {step} steps.\")\n",
    "                break\n",
    "\n",
    "            if step % 500 == 0:\n",
    "                print(f\"{step:06d} {loss:.2e}\")\n",
    "\n",
    "        else:\n",
    "            # the step loop ran out without hitting a break\n",
    "            print(f\"length {length} did not converge after {step} steps.\")\n",
    "            runs.append([seed, step])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "\n",
    "# One CSV row per (sentence length, seed) run; the last column is a constant 0.0 placeholder.\n",
    "csv_columns = [\"sentence_length\", \"seed\", \"hparams/epoch\", \"hparams/validate_best\"]\n",
    "csv_rows = [\n",
    "    [key, seed, step, .0]\n",
    "    for key in results\n",
    "    for seed, step in results[key]\n",
    "]\n",
    "pd.DataFrame(csv_rows, columns=csv_columns, index=None).to_csv(\"~/long-lstm.csv\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "dict_items([(5, [[9120, 52], [2783, 46], [2057, 33], [6549, 35], [3201, 42], [7063, 39], [5243, 44], [3102, 63], [5303, 39], [5819, 30], [3693, 36], [4884, 72], [2231, 44], [5514, 32], [8850, 35], [6861, 58], [3106, 43], [2378, 30], [8697, 44], [1821, 38], [9480, 40], [8483, 49], [1633, 35], [9678, 34], [6596, 34], [4509, 73], [8618, 42], [9765, 48], [6346, 56], [2969, 40]]), (10, [[9120, 147], [2783, 427], [2057, 183], [6549, 283], [3201, 205], [7063, 215], [5243, 123], [3102, 195], [5303, 72], [5819, 146], [3693, 296], [4884, 164], [2231, 211], [5514, 93], [8850, 196], [6861, 173], [3106, 323], [2378, 86], [8697, 105], [1821, 111], [9480, 214], [8483, 275], [1633, 102], [9678, 300], [6596, 145], [4509, 195], [8618, 188], [9765, 370], [6346, 206], [2969, 191]]), (20, [[9120, 742], [2783, 424], [2057, 338], [6549, 702], [3201, 292], [7063, 187], [5243, 411], [3102, 669], [5303, 366], [5819, 179], [3693, 244], [4884, 600], [2231, 365], [5514, 306], [8850, 317], [6861, 565], [3106, 317], [2378, 427], [8697, 298], [1821, 204], [9480, 407], [8483, 357], [1633, 345], [9678, 633], [6596, 521], [4509, 269], [8618, 578], [9765, 265], [6346, 432], [2969, 296]]), (50, [[9120, 624], [2783, 471], [2057, 191], [6549, 425], [3201, 472], [7063, 319], [5243, 977], [3102, 710], [5303, 274], [5819, 284], [3693, 363], [4884, 390], [2231, 703], [5514, 432], [8850, 510], [6861, 373], [3106, 533], [2378, 261], [8697, 574], [1821, 601], [9480, 303], [8483, 750], [1633, 835], [9678, 2246], [6596, 438], [4509, 553], [8618, 784], [9765, 489], [6346, 774], [2969, 1110]]), (100, [[9120, 559], [2783, 1514], [2057, 1170], [6549, 1444], [3201, 4095], [7063, 1127], [5243, 357], [3102, 985], [5303, 1790], [5819, 210], [3693, 2717], [4884, 746], [2231, 1027], [5514, 899], [8850, 876], [6861, 515], [3106, 15999], [2378, 733], [8697, 672], [1821, 305], [9480, 636], [8483, 691], [1633, 844], [9678, 380], [6596, 804], [4509, 599], [8618, 885], [9765, 1223], [6346, 2909], [2969, 2600]]), (200, 
[[9120, 1443], [2783, 468], [2057, 2254], [6549, 520], [3201, 1099], [7063, 1265], [5243, 2187], [3102, 492], [5303, 1236], [5819, 902], [3693, 477], [4884, 853], [2231, 792], [5514, 504], [8850, 923], [6861, 2452], [3106, 2227], [2378, 643], [8697, 595], [1821, 15999], [9480, 1167], [8483, 8949], [1633, 1610], [9678, 788], [6596, 627], [4509, 1297], [8618, 15999], [9765, 1569], [6346, 1206], [2969, 1305]]), (500, [[9120, 1543], [2783, 6640], [2057, 929], [6549, 2752], [3201, 1285]]), (1000, [[9120, 2485], [2783, 664], [2057, 728], [6549, 1176]])])"
      ]
     },
     "execution_count": 48,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Display the recorded [seed, step] entries for every sentence length.\n",
    "results.items()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:rvqe]",
   "language": "python",
   "name": "conda-env-rvqe-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
