{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Complexity of large contractions\n",
    "\n",
    "This notebook assesses the computational complexity of `einsum` on large contractions. To regenerate the data pasted in the cells below, run:\n",
    "```sh\n",
    "GROWTH_SIZE=50 pytest -s tests/infer/test_enum.py -k growth\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from matplotlib import pyplot\n",
    "%matplotlib inline\n",
    "%config InlineBackend.figure_format = 'svg'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Placeholders for the measurement globals read by plot(); each data cell below assigns real values.\n",
    "sizes = costs = times1 = times2 = None\n",
    "\n",
    "def _new_axes(title):\n",
    "    \"\"\"Create an 8x5 figure with a white background and return its titled axes.\"\"\"\n",
    "    fig, ax = pyplot.subplots(figsize=(8, 5))\n",
    "    fig.patch.set_color('white')\n",
    "    ax.set_title(title)\n",
    "    return ax\n",
    "\n",
    "\n",
    "def _finish_axes(ax, sizes, ylabel):\n",
    "    \"\"\"Set the shared x label and range, the given y label, the legend, and a tight layout.\"\"\"\n",
    "    ax.set_xlabel('problem size')\n",
    "    ax.set_xlim(0, max(sizes))\n",
    "    ax.set_ylabel(ylabel)\n",
    "    ax.legend(loc='best')\n",
    "    ax.figure.tight_layout()\n",
    "\n",
    "\n",
    "def plot(title, sizes=None, costs=None, times1=None, times2=None):\n",
    "    \"\"\"Draw the data-structure and run-time figures for one model.\n",
    "\n",
    "    Any series left as None is read from the notebook global of the same\n",
    "    name. Every data cell overwrites those globals, so pass the series\n",
    "    explicitly to plot a specific data set regardless of execution order.\n",
    "\n",
    "    :param str title: model name used as the prefix of both figure titles.\n",
    "    :param list sizes: problem sizes for the x axis; must be non-empty.\n",
    "    :param dict costs: maps a series name to one value per entry of sizes;\n",
    "        series are drawn in sorted name order.\n",
    "    :param list times1: 'optim + compute' run times in seconds, one per size.\n",
    "    :param list times2: 'compute' run times in seconds, one per size.\n",
    "    :raises ValueError: if a series is missing (no data cell has run and\n",
    "        none was passed) or sizes is empty.\n",
    "    \"\"\"\n",
    "    namespace = globals()\n",
    "    sizes = namespace.get('sizes') if sizes is None else sizes\n",
    "    costs = namespace.get('costs') if costs is None else costs\n",
    "    times1 = namespace.get('times1') if times1 is None else times1\n",
    "    times2 = namespace.get('times2') if times2 is None else times2\n",
    "    # Fail with a readable message instead of an obscure error on None / empty input.\n",
    "    if sizes is None or len(sizes) == 0 or costs is None or times1 is None or times2 is None:\n",
    "        raise ValueError('No data to plot for {!r}: run a data cell first '\n",
    "                         'or pass the series explicitly.'.format(title))\n",
    "\n",
    "    ax = _new_axes('{} data structures'.format(title))\n",
    "    for name, series in sorted(costs.items()):\n",
    "        ax.plot(sizes, series, label=name)\n",
    "    _finish_axes(ax, sizes, 'cost')\n",
    "\n",
    "    ax = _new_axes('{} run time'.format(title))\n",
    "    ax.plot(sizes, times1, label='optim + compute')\n",
    "    ax.plot(sizes, times2, label='compute')\n",
    "    _finish_axes(ax, sizes, 'time (sec)')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# HMM measurements pasted from the test output (see the pytest command at the top).\n",
    "# NOTE: these assignments overwrite the globals that plot() reads; re-run this cell before plot('HMM') if cells ran out of order.\n",
    "sizes = [3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50]\n",
    "costs = {'einsum': [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48], 'tensordot': [12, 16, 20, 24, 28, 32, 36, 40, 44, 48, 52, 56, 60, 64, 68, 72, 76, 80, 84, 88, 92, 96, 100, 104, 108, 112, 116, 120, 124, 128, 132, 136, 140, 144, 148, 152, 156, 160, 164, 168, 172, 176, 180, 184, 188, 192, 196, 200], 'tensor': [22, 31, 40, 49, 58, 67, 76, 85, 94, 103, 112, 121, 130, 139, 148, 157, 166, 175, 184, 193, 202, 211, 220, 229, 238, 247, 256, 265, 274, 283, 292, 301, 310, 319, 328, 337, 346, 355, 364, 373, 382, 391, 400, 409, 418, 427, 436, 445]}\n",
    "times1 = [0.01864790916442871, 0.015166997909545898, 0.017799854278564453, 0.021364927291870117, 0.0234529972076416, 0.03243708610534668, 0.03485298156738281, 0.03809309005737305, 0.04254293441772461, 0.043493032455444336, 0.04782605171203613, 0.051072120666503906, 0.05495715141296387, 0.06077980995178223, 0.06451010704040527, 0.06647181510925293, 0.07750391960144043, 0.10012388229370117, 0.09436392784118652, 0.08780503273010254, 0.09475111961364746, 0.08931398391723633, 0.1099538803100586, 0.10660696029663086, 0.10943722724914551, 0.11156201362609863, 0.11216998100280762, 0.11894893646240234, 0.12170791625976562, 0.1290268898010254, 0.13869500160217285, 0.1344318389892578, 0.13837814331054688, 0.14883112907409668, 0.14552593231201172, 0.1480569839477539, 0.14761590957641602, 0.15995121002197266, 0.16048288345336914, 0.16365408897399902, 0.16843199729919434, 0.2130718231201172, 0.17986297607421875, 0.1792001724243164, 0.1941969394683838, 0.2153019905090332, 0.20756793022155762, 0.19938111305236816]\n",
    "times2 = [0.010827064514160156, 0.014249086380004883, 0.016450166702270508, 0.020006895065307617, 0.025799989700317383, 0.02879500389099121, 0.03235912322998047, 0.036743879318237305, 0.04072308540344238, 0.04432511329650879, 0.04558587074279785, 0.051867008209228516, 0.05726289749145508, 0.058149099349975586, 0.06532096862792969, 0.0634920597076416, 0.07218098640441895, 0.12434697151184082, 0.07972311973571777, 0.08487296104431152, 0.08191704750061035, 0.13434886932373047, 0.10629105567932129, 0.10842609405517578, 0.10170793533325195, 0.10760092735290527, 0.11115694046020508, 0.1158750057220459, 0.12462496757507324, 0.1272139549255371, 0.13429498672485352, 0.1305849552154541, 0.14617490768432617, 0.18872499465942383, 0.1460709571838379, 0.13549304008483887, 0.1373729705810547, 0.15271997451782227, 0.15703701972961426, 0.1608130931854248, 0.21175909042358398, 0.18168210983276367, 0.17579412460327148, 0.17799592018127441, 0.1961660385131836, 0.20264911651611328, 0.25041794776916504, 0.1808319091796875]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Draw the data-structure and run-time figures for the HMM data cell above.\n",
    "plot(title='HMM')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# DBN measurements pasted from the test output (see the pytest command at the top).\n",
    "# NOTE: these assignments overwrite the globals that plot() reads; re-run this cell before plot('DBN') if cells ran out of order.\n",
    "sizes = [3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50]\n",
    "costs = {'einsum': [7, 9, 13, 15, 19, 21, 25, 27, 31, 33, 37, 39, 43, 45, 49, 51, 55, 57, 61, 63, 67, 69, 73, 75, 79, 81, 85, 87, 91, 93, 97, 99, 103, 105, 109, 111, 115, 117, 121, 123, 127, 129, 133, 135, 139, 141, 145, 147], 'tensordot': [18, 25, 30, 37, 42, 49, 54, 61, 66, 73, 78, 85, 90, 97, 102, 109, 114, 121, 126, 133, 138, 145, 150, 157, 162, 169, 174, 181, 186, 193, 198, 205, 210, 217, 222, 229, 234, 241, 246, 253, 258, 265, 270, 277, 282, 289, 294, 301], 'tensor': [46, 63, 80, 97, 114, 131, 148, 165, 182, 199, 216, 233, 250, 267, 284, 301, 318, 335, 352, 369, 386, 403, 420, 437, 454, 471, 488, 505, 522, 539, 556, 573, 590, 607, 624, 641, 658, 675, 692, 709, 726, 743, 760, 777, 794, 811, 828, 845]}\n",
    "times1 = [0.02198004722595215, 0.03037405014038086, 0.03350090980529785, 0.04224896430969238, 0.04834318161010742, 0.05909299850463867, 0.06626009941101074, 0.08351302146911621, 0.09097099304199219, 0.08897876739501953, 0.09535503387451172, 0.10136294364929199, 0.13000011444091797, 0.12712597846984863, 0.13105392456054688, 0.1476750373840332, 0.14663481712341309, 0.15439701080322266, 0.15521693229675293, 0.1650080680847168, 0.1742238998413086, 0.17893004417419434, 0.18517208099365234, 0.19159197807312012, 0.20879316329956055, 0.2737429141998291, 0.23352789878845215, 0.22190213203430176, 0.23365497589111328, 0.23900103569030762, 0.2523791790008545, 0.26091718673706055, 0.2820899486541748, 0.3140451908111572, 0.28127598762512207, 0.2906830310821533, 0.34561610221862793, 0.4711790084838867, 0.3032550811767578, 0.31789112091064453, 0.34140491485595703, 0.34586501121520996, 0.3419170379638672, 0.35588693618774414, 0.36873412132263184, 0.36976003646850586, 0.3961608409881592, 0.3883850574493408]\n",
    "times2 = [0.022389888763427734, 0.026437997817993164, 0.03232693672180176, 0.041667938232421875, 0.05160379409790039, 0.055931806564331055, 0.07128310203552246, 0.08053183555603027, 0.08122706413269043, 0.0810542106628418, 0.0922250747680664, 0.10212492942810059, 0.11983704566955566, 0.1128089427947998, 0.13935494422912598, 0.12748098373413086, 0.13879609107971191, 0.1859588623046875, 0.14890193939208984, 0.15740394592285156, 0.16302895545959473, 0.17653393745422363, 0.1802539825439453, 0.18121719360351562, 0.20098400115966797, 0.19684600830078125, 0.2023460865020752, 0.22677183151245117, 0.23773717880249023, 0.23118090629577637, 0.23914885520935059, 0.2430558204650879, 0.31301093101501465, 0.2789499759674072, 0.26804518699645996, 0.28461790084838867, 0.3887619972229004, 0.31357502937316895, 0.2947719097137451, 0.3141598701477051, 0.4249720573425293, 0.32235097885131836, 0.3292689323425293, 0.32982301712036133, 0.39942502975463867, 0.3410038948059082, 0.3757472038269043, 0.38117194175720215]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Draw the data-structure and run-time figures for the DBN data cell above.\n",
    "plot(title='DBN')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 2",
   "language": "python",
   "name": "python2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
