{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matlab.engine"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "import sys\n",
    "sys.path.append('..')\n",
    "from experiment import Experiment, MethodNest, Job\n",
    "from hyperbox import Hyperbox\n",
    "from relu_nets import ReLUNet\n",
    "from neural_nets import data_loaders as dl\n",
    "from neural_nets import train\n",
    "from lipMIP import LipProblem\n",
    "from other_methods import CLEVER, FastLip, LipLP, LipSDP, NaiveUB, RandomLB, SeqLip\n",
    "from other_methods import LOCAL_METHODS, GLOBAL_METHODS, OTHER_METHODS\n",
    "from utilities import Factory, DoEvery\n",
    "import utilities as utils\n",
    "import os\n",
    "import time\n",
    "import pickle \n",
    "import csv"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "with open('exp_3_random.pkl', 'rb') as f:\n",
    "    rand_results = pickle.load(f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "with open('exp_3_synthetic.pkl', 'rb') as f:\n",
    "    synth_results = pickle.load(f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def run_collector(result_list):\n",
    "    output = {}\n",
    "    for result in result_list:\n",
    "        for k, v in result.items():\n",
    "            output[k] = output.get(k, []) + [v]\n",
    "    return output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "rand_run = run_collector(rand_results)\n",
    "synth_run = run_collector(synth_results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "def series_getter(run, fxn):\n",
    "    if fxn is None:\n",
    "        fxn = lambda i: i\n",
    "    return {k: fxn(v, run) for k,v in run.items()}\n",
    "\n",
    "def value_fxn(vlist, run=None):\n",
    "    return np.array([list(v.values().values())[0] for v in vlist])\n",
    "\n",
    "def time_fxn(vlist, run=None):\n",
    "    return np.array([list(v.compute_times().values())[0] for v in vlist])    \n",
    "\n",
    "def err_fxn(vlist, run=None):\n",
    "    true_vals = value_fxn(run[0])\n",
    "    this_vals = value_fxn(vlist)\n",
    "    return np.array([this / true for this, true in zip(this_vals, true_vals)])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "rand_values = series_getter(rand_run, value_fxn)\n",
    "rand_times = series_getter(rand_run, time_fxn)\n",
    "rand_errors = series_getter(rand_run, err_fxn)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [],
   "source": [
    "def rowmaker(times, vals, errs):\n",
    "    mean_format = lambda x,y: '${:10.3f} \\pm {:10.3f}$'.format(x,y)\n",
    "    def err_format(e):\n",
    "        s = ''\n",
    "        if e > 1.0:\n",
    "            s = '+'\n",
    "        return '$' + s + '{:10.2f} \\%$'.format(e * 100 - 100.0).lstrip()\n",
    "    k_order = ['LP', 1.0, 0.1, 0.01, 0.0]\n",
    "    \n",
    "    rows = []\n",
    "    for k in k_order:\n",
    "        row = [k]\n",
    "        row.append(mean_format(np.median(times[k]), np.std(times[k])))\n",
    "        row.append(mean_format(np.median(vals[k]), np.std(vals[k])))\n",
    "        row.append(err_format(np.median(errs[k])))\n",
    "        rows.append(row)\n",
    "    return rows\n",
    "\n",
    "def write_rows_to_csv(rows, filename):\n",
    "    with open(filename, 'w', newline='') as f:\n",
    "        csvwriter = csv.writer(f, delimiter=',',\n",
    "                                quotechar='|', quoting=csv.QUOTE_MINIMAL)\n",
    "        for row in rows:\n",
    "            csvwriter.writerow(row)\n",
    "        "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [],
   "source": [
    "rand_rows = rowmaker(rand_times, rand_values, rand_errors)\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [],
   "source": [
    "synth_values = series_getter(synth_run, value_fxn)\n",
    "synth_times = series_getter(synth_run, time_fxn)\n",
    "synth_errors = series_getter(synth_run, err_fxn)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [],
   "source": [
    "synth_rows = rowmaker(synth_times, synth_values, synth_errors)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [],
   "source": [
    "both_rows = [x + y[1:] for x,y in zip(rand_rows, synth_rows)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[['LP',\n",
       "  '$     0.017 \\\\pm      0.003$',\n",
       "  '$     5.428 \\\\pm      1.053$',\n",
       "  '$+457.26 \\\\%$',\n",
       "  '$     0.031 \\\\pm      0.003$',\n",
       "  '$  2023.407 \\\\pm    690.018$',\n",
       "  '$+390.74 \\\\%$'],\n",
       " [1.0,\n",
       "  '$    92.194 \\\\pm    119.857$',\n",
       "  '$     1.874 \\\\pm      0.518$',\n",
       "  '$+95.32 \\\\%$',\n",
       "  '$    10.101 \\\\pm     12.773$',\n",
       "  '$   690.397 \\\\pm    264.017$',\n",
       "  '$+74.81 \\\\%$'],\n",
       " [0.1,\n",
       "  '$   264.243 \\\\pm    294.666$',\n",
       "  '$     1.082 \\\\pm      0.287$',\n",
       "  '$+9.84 \\\\%$',\n",
       "  '$    25.837 \\\\pm     53.403$',\n",
       "  '$   444.495 \\\\pm    189.612$',\n",
       "  '$+9.24 \\\\%$'],\n",
       " [0.01,\n",
       "  '$   266.202 \\\\pm    287.511$',\n",
       "  '$     0.991 \\\\pm      0.266$',\n",
       "  '$+0.49 \\\\%$',\n",
       "  '$    30.915 \\\\pm     63.484$',\n",
       "  '$   411.625 \\\\pm    178.890$',\n",
       "  '$+0.67 \\\\%$'],\n",
       " [0.0,\n",
       "  '$   265.894 \\\\pm    304.685$',\n",
       "  '$     0.988 \\\\pm      0.269$',\n",
       "  '$0.00 \\\\%$',\n",
       "  '$    30.976 \\\\pm     63.721$',\n",
       "  '$   408.434 \\\\pm    178.000$',\n",
       "  '$0.00 \\\\%$']]"
      ]
     },
     "execution_count": 33,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "both_rows"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [],
   "source": [
    "write_rows_to_csv(both_rows, 'exp_3_results.csv')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
