{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matlab.engine"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append('..')\n",
    "import numpy as np \n",
    "import torch\n",
    "from hyperbox import Hyperbox\n",
    "from interval_analysis import HBoxIA\n",
    "from relu_nets import ReLUNet\n",
    "from lipMIP import LipProblem\n",
    "from other_methods import CLEVER, FastLip, LipLP, LipSDP, NaiveUB, RandomLB, SeqLip\n",
    "from neural_nets import train\n",
    "from neural_nets import data_loaders as dl\n",
    "from experiment import Experiment, InstanceGroup, Result, ResultList, MethodNest\n",
    "from utilities import Factory\n",
    "import math\n",
    "import neural_nets.data_loaders as dl \n",
    "import neural_nets.train as train\n",
    "import glob\n",
    "import pickle\n",
    "import csv"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['l1_exp_2_MNIST_bin_radius02.pkl', 'l1_exp_2_MNIST_bin_radius01.pkl']"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "glob.glob('l1_exp_2*')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def dim_scale(k, val, dim):\n",
    "    if k not in ['SeqLip', 'LipSDP']:\n",
    "        return val\n",
    "    else:\n",
    "        return math.sqrt(dim) *val"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "def format_local_globals(result_list_list, global_result, dim):\n",
    "    num_per = [len(_.results) for _ in result_list_list]\n",
    "    \n",
    "    all_results = [_.results for _ in result_list_list]\n",
    "    local_rl = ResultList([_ for el in all_results for _ in el])\n",
    "    global_rl = ResultList(global_result)\n",
    "    local_times = local_rl.average_stdevs('time')\n",
    "    local_vals = local_rl.average_stdevs('value')\n",
    "    global_times = global_rl.average_stdevs('time')\n",
    "    global_vals = global_rl.average_stdevs('value')\n",
    "    local_errs = local_rl.get_rel_err(dim)\n",
    "    all_keys = list(local_times.keys()) + list(global_times.keys())\n",
    "    \n",
    "    # Get the values total dict:\n",
    "    all_values = {}\n",
    "    for val_dict in [local_vals, global_vals]:\n",
    "        for k,v in val_dict.items():\n",
    "            all_values[k] = (dim_scale(k,v[0], dim), dim_scale(k, v[1],dim))\n",
    "    # Get times total dict:\n",
    "    all_times = {}\n",
    "    for time_dict in [local_times, global_times]:\n",
    "        for k,v in time_dict.items():\n",
    "            all_times[k] = v\n",
    "   \n",
    "    # Get right answers in right order:\n",
    "    right_answers = [_.values('LipProblem') for _ in local_rl.results]\n",
    "    rel_err_dict = {}\n",
    "    \n",
    "    for k in global_vals.keys():\n",
    "        answer_idx = 0\n",
    "\n",
    "        if k not in rel_err_dict:\n",
    "            rel_err_dict[k] = []\n",
    "        for num, result in zip(num_per, global_rl.results):\n",
    "            for i in range(num):\n",
    "                right_answer = right_answers[answer_idx]\n",
    "                answer_idx +=1 \n",
    "                rel_err_dict[k].append(dim_scale(k, result.values(k), dim) / right_answer)\n",
    "    global_errs = {k: (np.array(v).mean(), np.array(v).std(), len(v)) for k,v in rel_err_dict.items()}\n",
    "    all_errs = {}\n",
    "\n",
    "    for err_dict in global_errs, local_errs:\n",
    "        for k,v in err_dict.items():\n",
    "            all_errs[k] = v\n",
    "\n",
    "    \n",
    "    # Key-order \n",
    "    key_order = [_[0] for _ in sorted(all_values.items(), key=lambda p:p[1])]\n",
    "    max_len_k = max(len(_) for _ in key_order + ['method'])\n",
    "    pad = lambda s: s + ' '*(max_len_k - len(s))\n",
    "\n",
    "    header_pad = lambda s: '|' + ' ' * (10 - len(s)) + s\n",
    "    header = pad('Method') +' '+ ' '.join(header_pad(_) for _ in ['Time', 'Time STD', 'Value', 'Value STD', 'Err', 'ErrSTD'])\n",
    "    print(header + '\\n' + '-'*len(header))\n",
    "    dformat = lambda s: '|{:10.4f}'.format(s)\n",
    "    mean_format = lambda x,y: '${:10.3f} \\pm {:10.3f}$'.format(x,y)\n",
    "    def err_format(e):\n",
    "        print(e)\n",
    "        s = ''\n",
    "        if e > 1.0:\n",
    "            s = '+'\n",
    "        return '$' + s + '{:10.2f} \\%$'.format(e * 100 - 100.0).lstrip()\n",
    "        \n",
    "    csv_rows = []\n",
    "    for k in key_order:\n",
    "        elements = [pad(k), \n",
    "                    dformat(all_times[k][0]), dformat(all_times[k][1]),\n",
    "                    #dformat(all_values[k][0]), dformat(all_values[k][1]),\n",
    "                    err_format(all_errs[k][0]), #dformat(all_errs[k][1])\n",
    "                   ]\n",
    "        print(' '.join(elements))\n",
    "        csv_row = [pad(k), mean_format(all_times[k][0], all_times[k][1]),\n",
    "                   err_format(all_errs[k][0])]\n",
    "        csv_rows.append(csv_row)\n",
    "    \n",
    "    return csv_rows\n",
    "        \n",
    "        \n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Method     |      Time |  Time STD |     Value | Value STD |       Err |    ErrSTD\n",
      "----------------------------------------------------------------------------------\n",
      "0.5575478\n",
      "RandomLB   |    0.3298 |    0.0057 $-44.25 \\%$\n",
      "0.5575478\n",
      "0.6169179162871867\n",
      "CLEVER     |    6.8555 |    4.9838 $-38.31 \\%$\n",
      "0.6169179162871867\n",
      "1.0\n",
      "LipProblem |    4.5504 |    2.5195 $0.00 \\%$\n",
      "1.0\n",
      "1.434530247850792\n",
      "FastLip    |    0.0014 |    0.0001 $+43.45 \\%$\n",
      "1.434530247850792\n",
      "1.434530247850792\n",
      "LipLP      |    0.2295 |    0.0212 $+43.45 \\%$\n",
      "1.434530247850792\n",
      "10.613136136593488\n",
      "NaiveUB    |    0.0003 |    0.0000 $+961.31 \\%$\n",
      "10.613136136593488\n",
      "162.4715415676271\n",
      "LipSDP     |   18.1838 |    1.9348 $+16147.15 \\%$\n",
      "162.4715415676271\n",
      "166.5965010487548\n",
      "SeqLip     |    0.0128 |    0.0037 $+16559.65 \\%$\n",
      "166.5965010487548\n"
     ]
    }
   ],
   "source": [
    "with open('l1_exp_2_MNIST_bin_radius01.pkl', 'rb') as f:\n",
    "    results1 = pickle.load(f)\n",
    "mnist_1_rows = format_local_globals(results1[0], results1[2], 784)\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Method     |      Time |  Time STD |     Value | Value STD |       Err |    ErrSTD\n",
      "----------------------------------------------------------------------------------\n",
      "0.6492438\n",
      "RandomLB   |    0.3250 |    0.0046 $-35.08 \\%$\n",
      "0.6492438\n",
      "0.7276114496811403\n",
      "CLEVER     |   23.0102 |    0.6115 $-27.24 \\%$\n",
      "0.7276114496811403\n",
      "1.0\n",
      "LipProblem |    4.2924 |    1.1835 $0.00 \\%$\n",
      "1.0\n",
      "1.3265890750665295\n",
      "LipLP      |    0.2332 |    0.0117 $+32.66 \\%$\n",
      "1.3265890750665295\n",
      "1.3265890750665295\n",
      "FastLip    |    0.0016 |    0.0007 $+32.66 \\%$\n",
      "1.3265890750665295\n",
      "9.910991493605861\n",
      "NaiveUB    |    0.0003 |    0.0000 $+891.10 \\%$\n",
      "9.910991493605861\n",
      "146.26915443447407\n",
      "LipSDP     |   20.1609 |    2.3327 $+14526.92 \\%$\n",
      "146.26915443447407\n",
      "150.17935307747277\n",
      "SeqLip     |    0.0212 |    0.0041 $+14917.94 \\%$\n",
      "150.17935307747277\n"
     ]
    }
   ],
   "source": [
    "with open('l1_exp_2_MNIST_bin_radius02.pkl', 'rb') as f:\n",
    "    results2 = pickle.load(f)\n",
    "mnist_2_rows = format_local_globals(results2[0], results2[2], 784)\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "([['RandomLB  ', '$     0.330 \\\\pm      0.006$', '$-44.25 \\\\%$'],\n",
       "  ['CLEVER    ', '$     6.855 \\\\pm      4.984$', '$-38.31 \\\\%$'],\n",
       "  ['LipProblem', '$     4.550 \\\\pm      2.519$', '$0.00 \\\\%$'],\n",
       "  ['FastLip   ', '$     0.001 \\\\pm      0.000$', '$+43.45 \\\\%$'],\n",
       "  ['LipLP     ', '$     0.229 \\\\pm      0.021$', '$+43.45 \\\\%$'],\n",
       "  ['NaiveUB   ', '$     0.000 \\\\pm      0.000$', '$+961.31 \\\\%$'],\n",
       "  ['LipSDP    ', '$    18.184 \\\\pm      1.935$', '$+16147.15 \\\\%$'],\n",
       "  ['SeqLip    ', '$     0.013 \\\\pm      0.004$', '$+16559.65 \\\\%$']],\n",
       " [['RandomLB  ', '$     0.325 \\\\pm      0.005$', '$-35.08 \\\\%$'],\n",
       "  ['CLEVER    ', '$    23.010 \\\\pm      0.612$', '$-27.24 \\\\%$'],\n",
       "  ['LipProblem', '$     4.292 \\\\pm      1.183$', '$0.00 \\\\%$'],\n",
       "  ['LipLP     ', '$     0.233 \\\\pm      0.012$', '$+32.66 \\\\%$'],\n",
       "  ['FastLip   ', '$     0.002 \\\\pm      0.001$', '$+32.66 \\\\%$'],\n",
       "  ['NaiveUB   ', '$     0.000 \\\\pm      0.000$', '$+891.10 \\\\%$'],\n",
       "  ['LipSDP    ', '$    20.161 \\\\pm      2.333$', '$+14526.92 \\\\%$'],\n",
       "  ['SeqLip    ', '$     0.021 \\\\pm      0.004$', '$+14917.94 \\\\%$']])"
      ]
     },
     "execution_count": 29,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "mnist_1_rows, mnist_2_rows"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [],
   "source": [
    "total_rows = []\n",
    "for el1, el2 in zip(mnist_1_rows, mnist_2_rows):\n",
    "    total_rows.append(el1 + el2[1:])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[['RandomLB  ',\n",
       "  '$     0.330 \\\\pm      0.006$',\n",
       "  '$-44.25 \\\\%$',\n",
       "  '$     0.325 \\\\pm      0.005$',\n",
       "  '$-35.08 \\\\%$'],\n",
       " ['CLEVER    ',\n",
       "  '$     6.855 \\\\pm      4.984$',\n",
       "  '$-38.31 \\\\%$',\n",
       "  '$    23.010 \\\\pm      0.612$',\n",
       "  '$-27.24 \\\\%$'],\n",
       " ['LipProblem',\n",
       "  '$     4.550 \\\\pm      2.519$',\n",
       "  '$0.00 \\\\%$',\n",
       "  '$     4.292 \\\\pm      1.183$',\n",
       "  '$0.00 \\\\%$'],\n",
       " ['FastLip   ',\n",
       "  '$     0.001 \\\\pm      0.000$',\n",
       "  '$+43.45 \\\\%$',\n",
       "  '$     0.233 \\\\pm      0.012$',\n",
       "  '$+32.66 \\\\%$'],\n",
       " ['LipLP     ',\n",
       "  '$     0.229 \\\\pm      0.021$',\n",
       "  '$+43.45 \\\\%$',\n",
       "  '$     0.002 \\\\pm      0.001$',\n",
       "  '$+32.66 \\\\%$'],\n",
       " ['NaiveUB   ',\n",
       "  '$     0.000 \\\\pm      0.000$',\n",
       "  '$+961.31 \\\\%$',\n",
       "  '$     0.000 \\\\pm      0.000$',\n",
       "  '$+891.10 \\\\%$'],\n",
       " ['LipSDP    ',\n",
       "  '$    18.184 \\\\pm      1.935$',\n",
       "  '$+16147.15 \\\\%$',\n",
       "  '$    20.161 \\\\pm      2.333$',\n",
       "  '$+14526.92 \\\\%$'],\n",
       " ['SeqLip    ',\n",
       "  '$     0.013 \\\\pm      0.004$',\n",
       "  '$+16559.65 \\\\%$',\n",
       "  '$     0.021 \\\\pm      0.004$',\n",
       "  '$+14917.94 \\\\%$']]"
      ]
     },
     "execution_count": 32,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "total_rows"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['RandomLB  ', '$     0.325 \\\\pm      0.005$', '$-35.08 \\\\%$']\n",
      "['CLEVER    ', '$    23.010 \\\\pm      0.612$', '$-27.24 \\\\%$']\n",
      "['LipProblem', '$     4.292 \\\\pm      1.183$', '$0.00 \\\\%$']\n",
      "['LipLP     ', '$     0.233 \\\\pm      0.012$', '$+32.66 \\\\%$']\n",
      "['FastLip   ', '$     0.002 \\\\pm      0.001$', '$+32.66 \\\\%$']\n",
      "['NaiveUB   ', '$     0.000 \\\\pm      0.000$', '$+891.10 \\\\%$']\n",
      "['LipSDP    ', '$    20.161 \\\\pm      2.333$', '$+14526.92 \\\\%$']\n",
      "['SeqLip    ', '$     0.021 \\\\pm      0.004$', '$+14917.94 \\\\%$']\n"
     ]
    }
   ],
   "source": [
    "for el in mnist_rows:\n",
    "    print(el)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [],
   "source": [
    "def write_rows_to_csv(rows, filename):\n",
    "    with open(filename, 'w', newline='') as f:\n",
    "        csvwriter = csv.writer(f, delimiter=',',\n",
    "                                quotechar='|', quoting=csv.QUOTE_MINIMAL)\n",
    "        for row in rows:\n",
    "            csvwriter.writerow(row)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [],
   "source": [
    "write_rows_to_csv(total_rows, 'l1_exp_2_MNIST_DATA_ERR.csv')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "results = []\n",
    "def scrub_results(results):\n",
    "    output = []\n",
    "    for _ in results.results:\n",
    "        if (_.values('CLEVER')) < 10000:\n",
    "            output.append(_)\n",
    "    new_results = ResultList(output)\n",
    "    return new_results\n",
    "synthetic_results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import csv\n",
    "def csv_parser(results, dim, csvfile, no_err=False):\n",
    "    times = results.average_stdevs('time')\n",
    "    values = results.average_stdevs('value')\n",
    "    rel_errs = results.get_rel_err(dim)\n",
    "    keys = sorted(list(times.keys()))\n",
    "    max_len_k = max(len(_) for _ in keys + ['method'])\n",
    "    rows = []\n",
    "    pad = lambda s: s + ' ' * (max_len_k - len(s))\n",
    "    def dim_scale(k, val):\n",
    "        if k not in ['SeqLip', 'LipSDP']:\n",
    "            return val\n",
    "        else:\n",
    "            return math.sqrt(dim) *val\n",
    "    header_pad = lambda s: '|' + ' ' * (10 - len(s)) + s\n",
    "    row1 = ['Method', 'Time', 'Time STD', 'Value', 'Value STD', 'Err', 'ErrSTD']\n",
    "    rows.append(row1)\n",
    "    header = pad('Method') +' '+ ' '.join(header_pad(_) for _ in ['Time', 'Time STD', 'Value', 'Value STD', 'Err', 'ErrSTD'])\n",
    "    print(header + '\\n' + '-' * len(header))\n",
    "    key_order = [_[0] for _ in sorted([(k, dim_scale(k, values[k][0])) for k in keys], key=lambda p: p[1])]\n",
    "    for k in key_order:\n",
    "        elements = [pad(k),  \n",
    "                    '${:10.3f} \\pm {:10.3f}$'.format(times[k][0], times[k][1]),\n",
    "                    '${:10.3f} \\pm {:10.3f}$'.format(dim_scale(k, values[k][0]), dim_scale(k, values[k][1]))]\n",
    "        if not no_err:\n",
    "            elements.append(\n",
    "                    '${:10.2f}\\% \\pm {:10.2f}\\%$'.format(rel_errs[k][0] * 100, rel_errs[k][1] * 100))\n",
    "        rows.append(elements)\n",
    "        print(' '.join(elements))\n",
    "    return rows\n",
    "    with open(csvfile, 'w', newline='') as f:\n",
    "        csvwriter = csv.writer(f, delimiter=',',\n",
    "                                quotechar='|', quoting=csv.QUOTE_MINIMAL)\n",
    "        for row in rows:\n",
    "            csvwriter.writerow(row)\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!ls *.pkl"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "synthetic_file = 'exp_1_synthetic.pkl'\n",
    "with open(synthetic_file, 'rb') as f:\n",
    "    synthetic_results = pickle.load(f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "synthetic_rows = csv_parser(scrub_results(synthetic_results), 10, None, no_err=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "random_rows = csv_parser(new_rand_results, 16, None, True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with open('exp_1_MNIST_bin.pkl', 'rb') as f:\n",
    "    mnist_results = scrub_results(pickle.load(f))\n",
    "mnist_rows = csv_parser(mnist_results, 784, None, True)\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "len(random_rows), len(synthetic_rows), len(mnist_rows)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "all_rows = []\n",
    "for i in range(1, 9):\n",
    "   all_rows.append(random_rows[i] + synthetic_rows[i][1:] + mnist_rows[i][1:]) \n",
    "\n",
    "\n",
    "rand_synth_rows = []\n",
    "for i in range(1, 9):\n",
    "    rand_synth_rows.append(random_rows[i] + synthetic_rows[i][1:])\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with open('experiment_1_randSynth.csv', 'w', newline='') as f:\n",
    "    csvwriter = csv.writer(f, delimiter=',',\n",
    "                                quotechar='|', quoting=csv.QUOTE_MINIMAL)\n",
    "    for row in rand_synth_rows:\n",
    "        csvwriter.writerow(row)\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with open('experiment_1_MNIST.csv', 'w', newline='') as f:\n",
    "    csvwriter = csv.writer(f, delimiter=',',\n",
    "                                quotechar='|', quoting=csv.QUOTE_MINIMAL)\n",
    "    for row in mnist_rows:\n",
    "        csvwriter.writerow(row)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "csv_parser(new_rand_results, 16, 'exp1RANDOMCSV.csv')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
