{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matlab.engine"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append('..')\n",
    "import numpy as np \n",
    "import torch\n",
    "from hyperbox import Hyperbox\n",
    "from interval_analysis import HBoxIA\n",
    "from relu_nets import ReLUNet\n",
    "from lipMIP import LipProblem\n",
    "from other_methods import CLEVER, FastLip, LipLP, LipSDP, NaiveUB, RandomLB, SeqLip\n",
    "from neural_nets import train\n",
    "from neural_nets import data_loaders as dl\n",
    "from experiment import Experiment, InstanceGroup, Result, ResultList, MethodNest\n",
    "from utilities import Factory\n",
    "import math\n",
    "import neural_nets.data_loaders as dl \n",
    "import neural_nets.train as train"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Methods evaluated per point / per local box via do_random_evals or do_data_evals.\n",
    "local_methods = [\n",
    "    LipProblem, LipLP, FastLip, RandomLB, CLEVER,\n",
    "]\n",
    "# Methods evaluated once per network via do_unit_hypercube_eval.\n",
    "global_methods = [\n",
    "    SeqLip, LipSDP, NaiveUB,\n",
    "]\n",
    "# Every method together (not referenced by the other cells of this notebook).\n",
    "box_methods = [\n",
    "    CLEVER, FastLip, LipLP, LipSDP, NaiveUB, RandomLB, LipProblem, SeqLip,\n",
    "]\n",
    "# Thread count forwarded to Experiment(..., num_threads=NUM_THREADS).\n",
    "NUM_THREADS = 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "def dim_scale(k, val, dim):\n",
    "    \"\"\"Rescale the value reported by method `k` for an input of dimension `dim`.\n",
    "\n",
    "    'SeqLip' and 'LipSDP' values are multiplied by sqrt(dim); values from every\n",
    "    other method are returned unchanged (presumably converting those two bounds\n",
    "    to the norm the other methods report -- confirm).\n",
    "    \"\"\"\n",
    "    if k in ('SeqLip', 'LipSDP'):\n",
    "        return math.sqrt(dim) * val\n",
    "    return val\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "def test_random(layer_sizes, k_trials, num_random, radius):\n",
    "    \"\"\"Compare local and global Lipschitz estimators on untrained random ReLU nets.\n",
    "\n",
    "    For each of `k_trials` freshly constructed ReLUNet(layer_sizes) networks, the\n",
    "    local methods run do_random_evals with `num_random` points, the unit hypercube\n",
    "    as sample domain and l_inf balls of `radius`; the global methods run\n",
    "    do_unit_hypercube_eval once.\n",
    "\n",
    "    Returns (local_results, global_results), one entry per trial in each list.\n",
    "    \"\"\"\n",
    "    # Good parameters are [16, 16, 16, 2]\n",
    "    assert layer_sizes[-1] == 2  # c_vector below has exactly two entries\n",
    "    in_dim = layer_sizes[0]\n",
    "    c_vector = np.array([1.0, -1.0])\n",
    "    sample_domain = Hyperbox.build_unit_hypercube(in_dim)\n",
    "    ball_factory = Factory(Hyperbox.build_linf_ball, radius=radius)\n",
    "\n",
    "    local_results, global_results = [], []\n",
    "    for _ in range(k_trials):\n",
    "        net = ReLUNet(layer_sizes=layer_sizes)\n",
    "        local_exp = Experiment(local_methods, network=net, c_vector=c_vector,\n",
    "                               primal_norm='linf', verbose=True,\n",
    "                               num_threads=NUM_THREADS)\n",
    "        local_results.append(local_exp.do_random_evals(\n",
    "            num_random_points=num_random, sample_domain=sample_domain,\n",
    "            ball_factory=ball_factory))\n",
    "        global_exp = Experiment(global_methods, network=net, c_vector=c_vector,\n",
    "                                primal_norm='linf')\n",
    "        global_results.append(global_exp.do_unit_hypercube_eval())\n",
    "    return local_results, global_results\n",
    "\n",
    "\n",
    "#output = test_random([16, 16, 16, 2], 2, 3, 0.1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "#format_resultList(ResultList(output[0][0].results + output[0][1].results), 16)\n",
    "\n",
    "#format_resultList(ResultList(output[1]), 16)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "def format_local_globals(result_list_list, global_result, dim):\n",
    "    \"\"\"Print a table comparing local and global Lipschitz estimators.\n",
    "\n",
    "    Args:\n",
    "        result_list_list: one local-method ResultList per network (e.g. the\n",
    "            first element returned by test_mnist_bin).\n",
    "        global_result: one global-method result per network, in the same order.\n",
    "        dim: input dimension, forwarded to dim_scale and ResultList.get_rel_err.\n",
    "\n",
    "    Rows are sorted by the rescaled (mean, std) value. For a global method,\n",
    "    'Err' averages its rescaled value divided by the LipProblem value of every\n",
    "    local point of the same network; local-method errors come from get_rel_err.\n",
    "\n",
    "    Raises:\n",
    "        AssertionError: if the number of networks on the two sides differs.\n",
    "    \"\"\"\n",
    "    num_per = [len(rl.results) for rl in result_list_list]\n",
    "    local_rl = ResultList([res for rl in result_list_list for res in rl.results])\n",
    "    global_rl = ResultList(global_result)\n",
    "    # zip() below would silently drop networks if these lengths disagreed.\n",
    "    assert len(num_per) == len(global_rl.results), (\n",
    "        'expected one global result per local ResultList')\n",
    "\n",
    "    local_times = local_rl.average_stdevs('time')\n",
    "    local_vals = local_rl.average_stdevs('value')\n",
    "    global_times = global_rl.average_stdevs('time')\n",
    "    global_vals = global_rl.average_stdevs('value')\n",
    "    local_errs = local_rl.get_rel_err(dim)\n",
    "\n",
    "    # (mean, std) of every method's value, rescaled with dim_scale.\n",
    "    all_values = {}\n",
    "    for val_dict in (local_vals, global_vals):\n",
    "        for k, v in val_dict.items():\n",
    "            all_values[k] = (dim_scale(k, v[0], dim), dim_scale(k, v[1], dim))\n",
    "    all_times = {k: v for d in (local_times, global_times) for k, v in d.items()}\n",
    "\n",
    "    # LipProblem value of every local point, in the same order as local_rl.results.\n",
    "    right_answers = [res.values('LipProblem') for res in local_rl.results]\n",
    "    rel_err_dict = {}\n",
    "    for k in global_vals.keys():\n",
    "        ratios = []\n",
    "        answer_idx = 0\n",
    "        for num, result in zip(num_per, global_rl.results):\n",
    "            # One global value per network, compared with each of its `num` local points.\n",
    "            scaled = dim_scale(k, result.values(k), dim)\n",
    "            for _ in range(num):\n",
    "                ratios.append(scaled / right_answers[answer_idx])\n",
    "                answer_idx += 1\n",
    "        rel_err_dict[k] = ratios\n",
    "    global_errs = {}\n",
    "    for k, v in rel_err_dict.items():\n",
    "        arr = np.array(v)\n",
    "        global_errs[k] = (arr.mean(), arr.std(), len(v))\n",
    "    # Later dicts win on key collisions, so local errors override global ones.\n",
    "    all_errs = {k: v for d in (global_errs, local_errs) for k, v in d.items()}\n",
    "\n",
    "    # Rows ordered by rescaled (mean, std) value.\n",
    "    key_order = sorted(all_values, key=all_values.get)\n",
    "    max_len_k = max(len(_) for _ in key_order + ['method'])\n",
    "    pad = lambda s: s + ' ' * (max_len_k - len(s))\n",
    "\n",
    "    header_pad = lambda s: '|' + ' ' * (10 - len(s)) + s\n",
    "    header = pad('Method') + ' ' + ' '.join(header_pad(_) for _ in ['Time', 'Time STD', 'Value', 'Value STD', 'Err', 'ErrSTD'])\n",
    "    print(header + '\\n' + '-' * len(header))\n",
    "    dformat = lambda s: '|{:10.4f}'.format(s)\n",
    "    for k in key_order:\n",
    "        elements = [pad(k),\n",
    "                    dformat(all_times[k][0]), dformat(all_times[k][1]),\n",
    "                    dformat(all_values[k][0]), dformat(all_values[k][1]),\n",
    "                    dformat(all_errs[k][0]), dformat(all_errs[k][1])]\n",
    "        print(' '.join(elements))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE: pickle.load can execute arbitrary code -- only load result files you produced yourself.\n",
    "# The path is relative to the kernel's working directory.\n",
    "# NOTE(review): obj is presumably the (data_results, random_results, global_results) tuple\n",
    "# returned by test_mnist_bin (it is indexed as obj[0] / obj[2] below) -- confirm.\n",
    "with open('../scripts/exp_2_MNIST_bin_radius01.pkl', 'rb') as f:\n",
    "    obj = pickle.load(f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Method     |      Time |  Time STD |     Value | Value STD |       Err |    ErrSTD\n",
      "----------------------------------------------------------------------------------\n",
      "RandomLB   |    0.3337 |    0.0186 |  130.6436 |   74.6569 |    0.5804 |    0.2964\n",
      "CLEVER     |   20.5742 |    4.3204 |  142.5186 |   78.0794 |    0.6303 |    0.3054\n",
      "LipProblem |   69.1874 |   70.1145 |  224.5098 |   42.6603 |    1.0000 |    0.0000\n",
      "LipLP      |    0.2258 |    0.0226 |  298.0375 |   36.1698 |    1.3939 |    0.4126\n",
      "FastLip    |    0.0015 |    0.0004 |  345.7959 |   37.3263 |    1.6341 |    0.5640\n",
      "LipSDP     |   20.5700 |    2.7531 |  434.9666 |   28.8018 |    2.1392 |    1.1975\n",
      "SeqLip     |    0.0219 |    0.0046 |  446.2232 |   30.1950 |    2.1953 |    1.2338\n",
      "NaiveUB    |    0.0003 |    0.0000 |  636.6679 |   60.9249 |    3.1268 |    1.7213\n"
     ]
    }
   ],
   "source": [
    "# Local results (obj[0]) vs. global results (obj[2]); 784 = flattened MNIST input size.\n",
    "format_local_globals(obj[0], obj[2], 784)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def test_synthetic(layer_sizes, k_trials, num_random, radius):\n",
    "    \"\"\"Train ReLU nets on a synthetic dataset, then compare Lipschitz estimators.\n",
    "\n",
    "    One RandomDataset (2000 points, random_seed=1234) is built up front. For each\n",
    "    of `k_trials` networks a fresh ReLUNet(layer_sizes) is trained with\n",
    "    TrainParameters(..., 500, ...), the local methods run do_random_evals with\n",
    "    `num_random` points and l_inf balls of `radius`, and the global methods run\n",
    "    do_unit_hypercube_eval.\n",
    "\n",
    "    Returns (local_results, global_results), one entry per trial in each list.\n",
    "    \"\"\"\n",
    "    # Good parameters are [16, 16, 16, 2]\n",
    "    in_dim = layer_sizes[0]\n",
    "\n",
    "    # Training data: the first part of split_train_val(1.0) serves as both train and val set.\n",
    "    data_params = dl.RandomKParameters(num_points=2000, k=20, radius=0.02,\n",
    "                                       dimension=in_dim)\n",
    "    dataset = dl.RandomDataset(data_params, batch_size=128, random_seed=1234)\n",
    "    train_set, _ = dataset.split_train_val(1.0)\n",
    "    train_params = train.TrainParameters(train_set, train_set, 500,\n",
    "                                         test_after_epoch=20)\n",
    "\n",
    "    # Evaluation setup shared by every trial.\n",
    "    sample_domain = Hyperbox.build_unit_hypercube(in_dim)\n",
    "    ball_factory = Factory(Hyperbox.build_linf_ball, radius=radius)\n",
    "    c_vector = np.array([1.0, -1.0])\n",
    "\n",
    "    local_results, global_results = [], []\n",
    "    for _ in range(k_trials):\n",
    "        trained_net = ReLUNet(layer_sizes=layer_sizes)\n",
    "        train.training_loop(trained_net, train_params)\n",
    "\n",
    "        local_exp = Experiment(local_methods, network=trained_net, c_vector=c_vector,\n",
    "                               primal_norm='linf', verbose=True, num_threads=NUM_THREADS)\n",
    "        local_results.append(local_exp.do_random_evals(\n",
    "            num_random_points=num_random, sample_domain=sample_domain,\n",
    "            ball_factory=ball_factory))\n",
    "\n",
    "        global_exp = Experiment(global_methods, network=trained_net, c_vector=c_vector,\n",
    "                                primal_norm='linf')\n",
    "        global_results.append(global_exp.do_unit_hypercube_eval())\n",
    "    return local_results, global_results\n",
    "\n",
    "\n",
    "synth_out = test_synthetic([10, 20, 30, 20, 2], 1, 10, 0.1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): format_resultList is defined in a later cell; run that cell first on a fresh kernel.\n",
    "format_resultList(synth_out[0][0], 10)\n",
    "# Global-method values of the first trial, rescaled with dim_scale for input dimension 10.\n",
    "print({k: dim_scale(k, v,10) for k,v in synth_out[1][0].values().items()})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "def test_mnist_bin(layer_sizes, k_trials, num_random, radius, digits=None, use_cuda=True):\n",
    "    \"\"\"Train binary MNIST classifiers and compare Lipschitz estimators on them.\n",
    "\n",
    "    Args:\n",
    "        layer_sizes: ReLUNet layer sizes. Good params here are [784, 20, 20, 2].\n",
    "        k_trials: number of independently trained networks.\n",
    "        num_random: number of shuffled validation images per network at which\n",
    "            the local methods are evaluated.\n",
    "        radius: radius of the l_inf ball built around each image.\n",
    "        digits: the MNIST digits to keep (default [1, 7]).\n",
    "        use_cuda: passed to the train/val loaders and training_loop. The images\n",
    "            used for evaluation are always loaded with use_cuda=False.\n",
    "\n",
    "    Returns:\n",
    "        (data_results, random_results, global_results), with one entry per trial\n",
    "        in data_results and global_results. random_results is always empty\n",
    "        because the random-point evaluation below is disabled.\n",
    "    \"\"\"\n",
    "    digits = [1, 7] if digits is None else list(digits)\n",
    "    data_results = []\n",
    "    random_results = []\n",
    "    global_results = []\n",
    "\n",
    "    train_data = dl.load_mnist_data('train', digits=digits, use_cuda=use_cuda)\n",
    "    val_data = dl.load_mnist_data('val', digits=digits, use_cuda=use_cuda)\n",
    "    # Loss = XEntropyReg + l1 weight penalty (scalar 1e-5); 10 training epochs.\n",
    "    xentropy_reg = train.XEntropyReg()\n",
    "    l1_reg = train.LpWeightReg(scalar=1e-5, lp='l1')\n",
    "    loss = train.LossFunctional(regularizers=[xentropy_reg, l1_reg])\n",
    "    train_params = train.TrainParameters(train_data, val_data, 10, loss_functional=loss)\n",
    "\n",
    "    cvec = np.array([1.0, -1.0])\n",
    "    ball_factory = Factory(Hyperbox.build_linf_ball, radius=radius)\n",
    "    for _ in range(k_trials):\n",
    "        net = ReLUNet(layer_sizes)\n",
    "        train.training_loop(net, train_params, use_cuda=use_cuda)\n",
    "\n",
    "        # Local methods at `num_random` shuffled validation images (CPU numpy batch).\n",
    "        local_exp = Experiment(local_methods, network=net, c_vector=cvec, primal_norm='linf',\n",
    "                               verbose=True, num_threads=NUM_THREADS)\n",
    "        data_loader = dl.load_mnist_data('val', digits=digits, use_cuda=False,\n",
    "                                         shuffle=True, batch_size=num_random)\n",
    "        data_to_check = next(iter(data_loader))[0].cpu().numpy()\n",
    "        data_results.append(local_exp.do_data_evals(data_to_check, ball_factory))\n",
    "        # Random-point evaluation is disabled, so random_results stays empty. To re-enable,\n",
    "        # build sample_domain = Hyperbox.build_unit_hypercube(layer_sizes[0]) and call\n",
    "        # random_results.append(local_exp.do_random_evals(num_random, sample_domain, ball_factory))\n",
    "\n",
    "        # Global methods, once per network.\n",
    "        global_exp = Experiment(global_methods, network=net, c_vector=cvec, primal_norm='linf',\n",
    "                                verbose=True, num_threads=NUM_THREADS)\n",
    "        global_results.append(global_exp.do_unit_hypercube_eval())\n",
    "    return data_results, random_results, global_results\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 00 | Accuracy: 99.45\n",
      "Epoch 01 | Accuracy: 99.63\n",
      "Epoch 02 | Accuracy: 99.68\n",
      "Epoch 03 | Accuracy: 99.68\n",
      "Epoch 04 | Accuracy: 99.70\n",
      "Epoch 05 | Accuracy: 99.72\n",
      "Epoch 06 | Accuracy: 99.73\n",
      "Epoch 07 | Accuracy: 99.74\n",
      "Epoch 08 | Accuracy: 99.75\n",
      "Epoch 09 | Accuracy: 99.78\n",
      "Changed value of parameter Threads to 1\n",
      "   Prev: 0  Min: 0  Max: 1024  Default: 0\n",
      "Gurobi Optimizer version 9.0.0 build v9.0.0rc2 (linux64)\n",
      "Optimize a model with 4321 rows, 3411 columns and 43155 nonzeros\n",
      "Model fingerprint: 0x28a3d291\n",
      "Variable types: 2597 continuous, 814 integer (814 binary)\n",
      "Coefficient statistics:\n",
      "  Matrix range     [1e-08, 1e+01]\n",
      "  Objective range  [1e+00, 1e+00]\n",
      "  Bounds range     [1e-06, 1e+01]\n",
      "  RHS range        [1e-06, 8e+00]\n",
      "Presolve removed 172 rows and 118 columns\n",
      "Presolve time: 0.05s\n",
      "Presolved: 4149 rows, 3293 columns, 35992 nonzeros\n",
      "Variable types: 2479 continuous, 814 integer (814 binary)\n",
      "\n",
      "Root relaxation: objective 2.873353e+02, 1640 iterations, 0.05 seconds\n",
      "\n",
      "    Nodes    |    Current Node    |     Objective Bounds      |     Work\n",
      " Expl Unexpl |  Obj  Depth IntInf | Incumbent    BestBd   Gap | It/Node Time\n",
      "\n",
      "     0     0  287.31009    0  798          -  287.31009      -     -    0s\n",
      "     0     0  245.26200    0  784          -  245.26200      -     -    0s\n",
      "     0     0  235.88186    0  780          -  235.88186      -     -    0s\n",
      "     0     0  235.69089    0  769          -  235.69089      -     -    0s\n",
      "     0     0  235.50769    0  758          -  235.50769      -     -    0s\n",
      "     0     0  235.34361    0  752          -  235.34361      -     -    0s\n",
      "     0     0  235.12522    0  744          -  235.12522      -     -    0s\n",
      "     0     0  234.95886    0  738          -  234.95886      -     -    0s\n",
      "     0     0  234.81639    0  795          -  234.81639      -     -    0s\n",
      "     0     0  234.70451    0  795          -  234.70451      -     -    0s\n",
      "     0     0  234.68341    0  795          -  234.68341      -     -    0s\n",
      "     0     0  234.67505    0  795          -  234.67505      -     -    0s\n",
      "     0     0  232.08253    0  723          -  232.08253      -     -    0s\n",
      "     0     0  231.93113    0  792          -  231.93113      -     -    0s\n",
      "     0     0  231.91312    0  796          -  231.91312      -     -    0s\n",
      "     0     0  230.24896    0  798          -  230.24896      -     -    1s\n",
      "     0     0  230.22409    0  795          -  230.22409      -     -    1s\n",
      "     0     0  229.24694    0  802          -  229.24694      -     -    1s\n",
      "     0     0  229.21078    0  803          -  229.21078      -     -    1s\n",
      "     0     0  228.97407    0  802          -  228.97407      -     -    1s\n",
      "     0     0  228.97407    0  802          -  228.97407      -     -    1s\n",
      "H    0     0                     198.7252193  228.97407  15.2%     -    3s\n",
      "     0     2  228.97386    0  802  198.72522  228.97386  15.2%     -    3s\n",
      "   560   435  214.21351  486   75  198.72522  228.95910  15.2%  12.5    5s\n",
      "*  646   416             540     222.3348275  228.95910  2.98%  14.3    5s\n",
      "H  652   394                     222.3348329  228.92205  2.96%  14.6    7s\n",
      "\n",
      "Cutting planes:\n",
      "  Gomory: 6\n",
      "  Cover: 339\n",
      "  Clique: 11\n",
      "  MIR: 5\n",
      "  Flow cover: 23\n",
      "  RLT: 173\n",
      "  Relax-and-lift: 440\n",
      "\n",
      "Explored 679 nodes (17980 simplex iterations) in 9.18 seconds\n",
      "Thread count was 1 (of 8 available processors)\n",
      "\n",
      "Solution count 2: 222.335 198.725 \n",
      "\n",
      "Optimal solution found (tolerance 1.00e-04)\n",
      "Best objective 2.223348329240e+02, best bound 2.223375828812e+02, gap 0.0012%\n",
      "Changed value of parameter Threads to 1\n",
      "   Prev: 0  Min: 0  Max: 1024  Default: 0\n",
      "Gurobi Optimizer version 9.0.0 build v9.0.0rc2 (linux64)\n",
      "Optimize a model with 3281 rows, 3074 columns and 40399 nonzeros\n",
      "Model fingerprint: 0xc36e3da9\n",
      "Variable types: 2597 continuous, 477 integer (477 binary)\n",
      "Coefficient statistics:\n",
      "  Matrix range     [1e-08, 1e+01]\n",
      "  Objective range  [1e+00, 1e+00]\n",
      "  Bounds range     [1e-06, 2e+01]\n",
      "  RHS range        [1e-06, 1e+01]\n",
      "Presolve removed 916 rows and 843 columns\n",
      "Presolve time: 0.16s\n",
      "Presolved: 2365 rows, 2231 columns, 26164 nonzeros\n",
      "Variable types: 1777 continuous, 454 integer (454 binary)\n",
      "\n",
      "Root relaxation: objective 2.212533e+02, 1175 iterations, 0.02 seconds\n",
      "\n",
      "    Nodes    |    Current Node    |     Objective Bounds      |     Work\n",
      " Expl Unexpl |  Obj  Depth IntInf | Incumbent    BestBd   Gap | It/Node Time\n",
      "\n",
      "     0     0  221.25325    0  435          -  221.25325      -     -    0s\n",
      "     0     0  205.22354    0  389          -  205.22354      -     -    0s\n",
      "     0     0  205.12417    0  388          -  205.12417      -     -    0s\n",
      "     0     0  205.12406    0  387          -  205.12406      -     -    0s\n",
      "     0     0  203.75762    0  405          -  203.75762      -     -    0s\n",
      "     0     0  203.72862    0  405          -  203.72862      -     -    0s\n",
      "     0     0  203.72732    0  405          -  203.72732      -     -    0s\n",
      "     0     0  202.76447    0  404          -  202.76447      -     -    0s\n",
      "     0     0  202.69633    0  404          -  202.69633      -     -    0s\n",
      "     0     0  202.69583    0  404          -  202.69583      -     -    0s\n",
      "     0     0  202.39993    0  399          -  202.39993      -     -    0s\n",
      "     0     0  202.39574    0  399          -  202.39574      -     -    0s\n",
      "     0     0  202.32420    0  380          -  202.32420      -     -    0s\n",
      "     0     0  202.32416    0  380          -  202.32416      -     -    0s\n",
      "     0     0  202.30390    0  397          -  202.30390      -     -    0s\n",
      "     0     0  202.30390    0  397          -  202.30390      -     -    0s\n",
      "     0     2  202.28966    0  397          -  202.28966      -     -    1s\n",
      "*  342    87             135     182.7019259  191.39897  4.76%  22.9    1s\n",
      "\n",
      "Cutting planes:\n",
      "  Gomory: 1\n",
      "  Cover: 74\n",
      "  Implied bound: 2\n",
      "  MIR: 55\n",
      "  Flow cover: 46\n",
      "  RLT: 63\n",
      "  Relax-and-lift: 54\n",
      "\n",
      "Explored 506 nodes (11351 simplex iterations) in 2.40 seconds\n",
      "Thread count was 1 (of 8 available processors)\n",
      "\n",
      "Solution count 1: 182.702 \n",
      "\n",
      "Optimal solution found (tolerance 1.00e-04)\n",
      "Best objective 1.827019258927e+02, best bound 1.827019258927e+02, gap 0.0000%\n",
      "Changed value of parameter Threads to 1\n",
      "   Prev: 0  Min: 0  Max: 1024  Default: 0\n",
      "Gurobi Optimizer version 9.0.0 build v9.0.0rc2 (linux64)\n",
      "Optimize a model with 4345 rows, 3415 columns and 43211 nonzeros\n",
      "Model fingerprint: 0xc1ec4017\n",
      "Variable types: 2597 continuous, 818 integer (818 binary)\n",
      "Coefficient statistics:\n",
      "  Matrix range     [1e-08, 1e+01]\n",
      "  Objective range  [1e+00, 1e+00]\n",
      "  Bounds range     [1e-06, 1e+01]\n",
      "  RHS range        [1e-06, 8e+00]\n",
      "Presolve removed 168 rows and 110 columns\n",
      "Presolve time: 0.07s\n",
      "Presolved: 4177 rows, 3305 columns, 36060 nonzeros\n",
      "Variable types: 2487 continuous, 818 integer (818 binary)\n",
      "\n",
      "Root relaxation: objective 2.901767e+02, 1683 iterations, 0.05 seconds\n",
      "\n",
      "    Nodes    |    Current Node    |     Objective Bounds      |     Work\n",
      " Expl Unexpl |  Obj  Depth IntInf | Incumbent    BestBd   Gap | It/Node Time\n",
      "\n",
      "     0     0  290.14805    0  802          -  290.14805      -     -    0s\n",
      "     0     0  254.68386    0  768          -  254.68386      -     -    0s\n",
      "     0     0  240.59923    0  789          -  240.59923      -     -    0s\n",
      "     0     0  240.30809    0  784          -  240.30809      -     -    0s\n",
      "     0     0  240.24078    0  779          -  240.24078      -     -    0s\n",
      "     0     0  240.13070    0  774          -  240.13070      -     -    0s\n",
      "     0     0  240.10702    0  772          -  240.10702      -     -    0s\n",
      "     0     0  239.97276    0  768          -  239.97276      -     -    0s\n",
      "     0     0  239.95049    0  766          -  239.95049      -     -    0s\n",
      "     0     0  239.82505    0  764          -  239.82505      -     -    0s\n",
      "     0     0  239.73184    0  762          -  239.73184      -     -    0s\n",
      "     0     0  239.67244    0  760          -  239.67244      -     -    0s\n",
      "     0     0  239.64649    0  799          -  239.64649      -     -    0s\n",
      "     0     0  239.61002    0  799          -  239.61002      -     -    0s\n",
      "     0     0  239.60527    0  799          -  239.60527      -     -    0s\n",
      "     0     0  236.66496    0  757          -  236.66496      -     -    0s\n",
      "     0     0  236.54488    0  781          -  236.54488      -     -    0s\n",
      "     0     0  236.53574    0  788          -  236.53574      -     -    0s\n",
      "     0     0  234.83626    0  780          -  234.83626      -     -    1s\n",
      "     0     0  234.80365    0  777          -  234.80365      -     -    1s\n",
      "     0     0  234.28624    0  804          -  234.28624      -     -    1s\n",
      "     0     0  234.25644    0  804          -  234.25644      -     -    1s\n",
      "     0     0  234.22836    0  768          -  234.22836      -     -    1s\n",
      "     0     0  234.22836    0  768          -  234.22836      -     -    1s\n",
      "H    0     0                      65.5744005  234.22836   257%     -    3s\n",
      "     0     2  234.22504    0  768   65.57440  234.22504   257%     -    3s\n",
      "   493   425     cutoff  449        65.57440  234.19547   257%  13.4    5s\n",
      "*  873   619             749     222.3348215  234.19547  5.33%  34.5    7s\n",
      "   880   618  227.63425  182  832  222.33482  234.18753  5.33%  34.3   10s\n",
      "   963   673  227.88774   81  863  222.33482  227.88774  2.50%  31.4   15s\n",
      "  1021   712  226.04894  262  862  222.33482  227.77222  2.45%  29.6   20s\n",
      "  1045   728  222.42034  655  866  222.33482  227.37246  2.27%  28.9   25s\n",
      "  1069   745  223.16065  108  113  222.33482  223.16065  0.37%  42.9   31s\n",
      "\n",
      "Explored 1074 nodes (55656 simplex iterations) in 31.62 seconds\n",
      "Thread count was 1 (of 8 available processors)\n",
      "\n",
      "Solution count 2: 222.335 65.5744 \n",
      "\n",
      "Optimal solution found (tolerance 1.00e-04)\n",
      "Best objective 2.223348223194e+02, best bound 2.223348223194e+02, gap 0.0000%\n"
     ]
    }
   ],
   "source": [
    "# One trained [784, 20, 20, 20, 2] net, 3 validation images, l_inf radius 0.1.\n",
    "mnist_out = test_mnist_bin([784, 20, 20, 20, 2], 1, 3, 0.1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Method     |      Time |  Time STD |     Value | Value STD |       Err |    ErrSTD\n",
      "----------------------------------------------------------------------------------\n",
      "RandomLB   |    0.3341 |    0.0036 |  113.3396 |   60.3840 |    0.5324 |    0.2566\n",
      "CLEVER     |    5.2868 |    0.0350 |  118.1102 |   67.1299 |    0.5538 |    0.2869\n",
      "LipProblem |   14.5292 |   12.4994 |  209.1239 |   18.6831 |    1.0000 |    0.0000\n",
      "LipLP      |    0.2050 |    0.0355 |  267.0902 |   30.7458 |    1.2740 |    0.0357\n",
      "FastLip    |    0.0014 |    0.0000 |  313.4928 |   30.0406 |    1.4981 |    0.0105\n",
      "LipSDP     |   17.2346 |    0.0000 |  397.3521 |    0.0000 |    1.9164 |    0.1828\n",
      "SeqLip     |    0.0124 |    0.0000 |  403.4369 |    0.0000 |    1.9458 |    0.1856\n",
      "NaiveUB    |    0.0003 |    0.0000 |  619.9666 |    0.0000 |    2.9901 |    0.2851\n",
      " \n"
     ]
    }
   ],
   "source": [
    "# mnist_out = (data_results, random_results, global_results) from test_mnist_bin.\n",
    "format_local_globals(mnist_out[0], mnist_out[2], 784)\n",
    "\n",
    "print(' ')\n",
    "# mnist_out[1] (random_results) is empty while test_mnist_bin's random evals are commented out.\n",
    "#format_local_globals(mnist_out[1], mnist_out[2], 784)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(mnist_out)\n",
    "format_resultList(mnist_out[0][0], 784)\n",
    "print('\\n')\n",
    "# mnist_out[1] (random-point results) is empty while test_mnist_bin's random\n",
    "# evaluation is commented out, so mnist_out[1][0] would raise IndexError.\n",
    "if mnist_out[1]:\n",
    "    format_resultList(mnist_out[1][0], 784)\n",
    "print({k: dim_scale(k, v, 784) for k, v in mnist_out[2][0].values().items()})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "format_resultList(mnist_out[0][0], 784)\n",
    "print('\\n')\n",
    "# mnist_out[1] (random-point results) is empty while test_mnist_bin's random\n",
    "# evaluation is commented out, so mnist_out[1][0] would raise IndexError.\n",
    "if mnist_out[1]:\n",
    "    format_resultList(mnist_out[1][0], 784)\n",
    "print()\n",
    "print({k: dim_scale(k, v, 784) for k, v in mnist_out[2][0].values().items()})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): assumes `output` is the (local_results, global_results) tuple from the first\n",
    "# test_random; later cells rebind `output` (a ResultList, then a loop variable) -- confirm.\n",
    "format_resultList(output[0][0], 16)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): format_resultList calls .average_stdevs on its argument; if `output` is the\n",
    "# tuple from the first test_random, output[0] is a plain list -- confirm which `output` is live.\n",
    "format_resultList(output[0], 16)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def test_synthetic(layer_sizes, k_trials, num_random, radius):\n",
    "    \"\"\"Train ReLU nets on a synthetic dataset and run only the local methods.\n",
    "\n",
    "    NOTE: this redefinition shadows the earlier test_synthetic, which also runs the\n",
    "    global methods and returns a (local, global) tuple. This version returns a\n",
    "    list with one do_random_evals result per trial, each evaluated at\n",
    "    `num_random` points of the unit hypercube with l_inf balls of `radius`.\n",
    "    \"\"\"\n",
    "    # exp1 layers: 16,16,16,2\n",
    "    data_params = dl.RandomKParameters(num_points=2000, k=20, radius=0.02,\n",
    "                                       dimension=layer_sizes[0])\n",
    "    dataset = dl.RandomDataset(data_params, batch_size=128,\n",
    "                               random_seed=1234)\n",
    "    train_set, _ = dataset.split_train_val(1.0)\n",
    "\n",
    "    train_params = train.TrainParameters(train_set, train_set, 500,\n",
    "                                         test_after_epoch=20)\n",
    "    results = []\n",
    "    c_vector = np.array([1.0, -1.0])\n",
    "    sample_domain = Hyperbox.build_unit_hypercube(layer_sizes[0])\n",
    "    ball_factory = Factory(Hyperbox.build_linf_ball, radius=radius)\n",
    "    for _ in range(k_trials):\n",
    "        net = ReLUNet(layer_sizes=layer_sizes)\n",
    "        train.training_loop(net, train_params)\n",
    "        exp = Experiment(local_methods, network=net, c_vector=c_vector, primal_norm='linf', verbose=True, num_threads=4)\n",
    "        results.append(exp.do_random_evals(num_random, sample_domain, ball_factory))\n",
    "        print('\\n')\n",
    "    return results\n",
    "\n",
    "\n",
    "# Pass every argument by keyword: the signature takes four arguments\n",
    "# (num_random=10 matches the earlier test_synthetic call).\n",
    "synth_out = test_synthetic([10, 20, 30, 20, 2], k_trials=2, num_random=10, radius=0.2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# synth_out comes from nets with input dimension 10, so pass dim=10 (not MNIST's 784).\n",
    "format_resultList(synth_out[1], 10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def test_random(layer_sizes, k, radius, num_random=5):\n",
    "    \"\"\"Evaluate the local methods on `k` untrained random ReLU nets.\n",
    "\n",
    "    NOTE: this redefinition shadows the earlier test_random(layer_sizes, k_trials,\n",
    "    num_random, radius) and has a different signature and return type.\n",
    "\n",
    "    Each net runs do_random_evals with `num_random` points (default 5), the unit\n",
    "    hypercube as sample domain and l_inf balls of `radius`. Returns a ResultList\n",
    "    wrapping one do_random_evals result per net.\n",
    "    \"\"\"\n",
    "    assert layer_sizes[-1] == 2  # c_vector below has exactly two entries\n",
    "    c_vector = np.array([1.0, -1.0])\n",
    "    results = []\n",
    "    sample_domain = Hyperbox.build_unit_hypercube(layer_sizes[0])\n",
    "    ball_factory = Factory(Hyperbox.build_linf_ball, radius=radius)\n",
    "    for _ in range(k):\n",
    "        random_net = ReLUNet(layer_sizes=layer_sizes)\n",
    "        exp = Experiment(local_methods, network=random_net, c_vector=c_vector,\n",
    "                         primal_norm='linf', verbose=True)\n",
    "        results.append(exp.do_random_evals(num_random, sample_domain, ball_factory))\n",
    "    return ResultList(results)\n",
    "\n",
    "\n",
    "output = test_random([16, 16, 16, 2], 2, 0.1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-method value statistics (presumably (mean, std)) for the first network's local results.\n",
    "output.results[0].average_stdevs('value')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def format_resultList(results, dim):\n",
    "    \"\"\"Print a time / value / relative-error table for one ResultList.\n",
    "\n",
    "    Args:\n",
    "        results: a ResultList, e.g. one do_random_evals result from test_random\n",
    "            or test_synthetic.\n",
    "        dim: input dimension, forwarded to the notebook-level dim_scale (SeqLip /\n",
    "            LipSDP values are multiplied by sqrt(dim)) and to results.get_rel_err.\n",
    "\n",
    "    Rows are sorted by rescaled mean value. The two Err values are appended to a\n",
    "    row only when get_rel_err returns a non-empty result, and are multiplied by 100.\n",
    "    \"\"\"\n",
    "    times = results.average_stdevs('time')\n",
    "    values = results.average_stdevs('value')\n",
    "    rel_errs = results.get_rel_err(dim)\n",
    "    keys = sorted(times.keys())\n",
    "    max_len_k = max(len(k) for k in keys + ['method'])\n",
    "    pad = lambda s: s + ' ' * (max_len_k - len(s))\n",
    "\n",
    "    header_pad = lambda s: '|' + ' ' * (10 - len(s)) + s\n",
    "    header = pad('Method') + ' ' + ' '.join(header_pad(col) for col in ['Time', 'Time STD', 'Value', 'Value STD', 'Err', 'ErrSTD'])\n",
    "    print(header + '\\n' + '-' * len(header))\n",
    "    # Stable sort: ties keep alphabetical order from `keys`.\n",
    "    key_order = sorted(keys, key=lambda k: dim_scale(k, values[k][0], dim))\n",
    "    for k in key_order:\n",
    "        elements = [pad(k),\n",
    "                    '|{:10.4f}'.format(times[k][0]),\n",
    "                    '|{:10.4f}'.format(times[k][1]),\n",
    "                    '|{:10.4f}'.format(dim_scale(k, values[k][0], dim)),\n",
    "                    '|{:10.4f}'.format(dim_scale(k, values[k][1], dim))]\n",
    "        if rel_errs:\n",
    "            elements.extend(['|{:10.4f}'.format(rel_errs[k][0] * 100),\n",
    "                             '|{:10.4f}'.format(rel_errs[k][1] * 100)])\n",
    "        print(' '.join(elements))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def random_cvec(num_classes=10):\n",
    "    \"\"\"Return a random c_vector comparing two distinct output logits.\n",
    "\n",
    "    Draws two different indices i != j in [0, num_classes) and returns a length-\n",
    "    `num_classes` vector with +1.0 at i and -1.0 at j -- the multi-class analogue\n",
    "    of the np.array([1.0, -1.0]) c_vector used for the two-output nets above.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if num_classes < 2 (no two distinct indices exist).\n",
    "    \"\"\"\n",
    "    if num_classes < 2:\n",
    "        raise ValueError('random_cvec needs num_classes >= 2, got {}'.format(num_classes))\n",
    "    idx_pos = np.random.randint(num_classes)\n",
    "    # Rejection-sample a second index different from the first.\n",
    "    idx_neg = np.random.randint(num_classes)\n",
    "    while idx_neg == idx_pos:\n",
    "        idx_neg = np.random.randint(num_classes)\n",
    "    output = np.zeros(num_classes)\n",
    "    output[idx_pos] = 1.0\n",
    "    # Must be negative: two +1 entries would add the logits instead of differencing them.\n",
    "    output[idx_neg] = -1.0\n",
    "    return output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "DIGITS = None\n",
    "USE_CUDA = True\n",
    "train_data = dl.load_mnist_data('train', digits=DIGITS, use_cuda=USE_CUDA)\n",
    "val_data = dl.load_mnist_data('val', digits=DIGITS, use_cuda=USE_CUDA)\n",
    "xentropyReg = train.XEntropyReg()\n",
    "l1_reg = train.LpWeightReg(scalar=1e-3, lp='l1')\n",
    "loss = train.LossFunctional(regularizers=[xentropyReg, l1_reg])\n",
    "train_params = train.TrainParameters(train_data, val_data, 50, loss_functional=loss, test_after_epoch=5)\n",
    "# NOTE(review): this is a flattened torch batch (possibly on CUDA), whereas test_mnist_bin\n",
    "# hands do_data_evals a CPU numpy array -- confirm do_data_evals accepts both.\n",
    "train_batch_1 = next(iter(train_data))[0].view(-1, 784)\n",
    "\n",
    "\n",
    "def test_mnist(layer_sizes, k, radius):\n",
    "    \"\"\"Train `k` MNIST nets and run the local methods on one fixed training batch.\n",
    "\n",
    "    Reads the notebook-level train_params, USE_CUDA and train_batch_1 defined\n",
    "    above. Each net gets a fresh random_cvec() (length 10, matching the 10-way\n",
    "    output layer used below). Returns one do_data_evals result per net.\n",
    "    \"\"\"\n",
    "    sample_domain = Hyperbox.build_unit_hypercube(layer_sizes[0])  # not used below\n",
    "    ball_factory = Factory(Hyperbox.build_linf_ball, radius=radius)\n",
    "    per_net_results = []\n",
    "    for _ in range(k):\n",
    "        net = ReLUNet(layer_sizes)\n",
    "        train.training_loop(net, train_params, use_cuda=USE_CUDA)\n",
    "        exp = Experiment(local_methods, network=net, c_vector=random_cvec(),\n",
    "                         primal_norm='linf', verbose=True, num_threads=4)\n",
    "        per_net_results.append(exp.do_data_evals(train_batch_1, ball_factory, num_random=5))\n",
    "    return per_net_results\n",
    "\n",
    "\n",
    "mnist_test1 = test_mnist([784, 20, 20, 20, 10], 1, 0.1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Table for the first (and, with k=1, only) MNIST network.\n",
    "format_resultList(mnist_test1[0], 784)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): duplicate of the previous cell; one of the two can be removed.\n",
    "format_resultList(mnist_test1[0], 784)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "rl = ResultList(mnist_test1)\n",
    "\n",
    "\n",
    "def errors(result, dim=784):\n",
    "    \"\"\"Return (method, value ratio) pairs for one result, sorted by ratio.\n",
    "\n",
    "    Each method's value is rescaled with the notebook-level dim_scale and divided\n",
    "    by the 'LipProblem' value, so LipProblem itself maps to 1.0. `dim` defaults\n",
    "    to 784 (flattened MNIST).\n",
    "    \"\"\"\n",
    "    vals = result.values()\n",
    "    true = vals['LipProblem']\n",
    "    return sorted([(k, dim_scale(k, v, dim) / true) for k, v in vals.items()],\n",
    "                  key=lambda p: p[1])\n",
    "\n",
    "\n",
    "outputs = [errors(result) for result in rl.results]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Average each method's error ratio over all entries of `outputs`. The loop variable\n",
    "# is not named `output`, so it does not overwrite the notebook-level `output`\n",
    "# that other cells pass to format_resultList.\n",
    "dict_outs = {}\n",
    "for error_pairs in outputs:\n",
    "    for k, v in error_pairs:\n",
    "        dict_outs.setdefault(k, []).append(v)\n",
    "avg_outs = {k: sum(v) / len(v) for k, v in dict_outs.items()}\n",
    "avg_outs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
