{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matlab.engine"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append('..')\n",
    "from relu_nets import ReLUNet\n",
    "from experiment import Experiment, MethodNest, Job\n",
    "from neural_nets import train \n",
    "from neural_nets import data_loaders as dl\n",
    "from lipMIP import LipProblem\n",
    "from other_methods import CLEVER, SeqLip, FastLip, LipLP\n",
    "from utilities import Factory\n",
    "from hyperbox import Hyperbox\n",
    "import numpy as np\n",
    "import torch\n",
    "import pickle"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# First make a network \n",
    "\n",
    "test_net = ReLUNet(layer_sizes=[2, 10, 10, 2])\n",
    "class_list = [LipProblem]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Arguments for random\n",
    "sample_domain = Hyperbox.build_unit_hypercube(2)\n",
    "ball_factory = Factory(Hyperbox.build_linf_ball, radius=0.2, \n",
    "                       global_lo=np.array([0.0, 0.0]), global_hi=np.array([1.0, 1.0]))\n",
    "arg_bundle = {'num_random_points': 100, 'ball_factory': ball_factory, 'sample_domain': sample_domain}\n",
    "random_method = MethodNest(Experiment.do_random_evals, arg_bundle)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Arguments for Linf ball \n",
    "unit_method = MethodNest(Experiment.do_unit_hypercube_eval, {})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Arguments for data \n",
    "data_params = dl.RandomKParameters(100, 10, radius=0.01)\n",
    "loader_kwargs = {'batch_size': 50, 'random_seed': 420}\n",
    "data_arg_bundle = {'data_type': 'synthetic', \n",
    "                   'params': data_params, \n",
    "                   'loader_kwargs': loader_kwargs,\n",
    "                   'ball_factory': ball_factory}\n",
    "data_method = MethodNest(Experiment.do_data_evals, data_arg_bundle)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Experiment object \n",
    "exp = Experiment(class_list, network=test_net, c_vector=torch.tensor([1.0, -1.0]), primal_norm='linf')\n",
    "#result_out = exp.do_random_evals(**arg_bundle)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "job = Job('test_job2', exp, [unit_method], tag='foobar')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "output = job.run()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "job.write()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "job2 = Job.from_file('test_job2.job')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'do_unit_hypercube_eval': <experiment.Result at 0x7f5f06913dd8>,\n",
       " 'Job': <experiment.Job at 0x7f5f068e30f0>}"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "job2.run()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "output = job2.run()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "lipresult = output['do_unit_hypercube_eval'].objects()['LipProblem']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'sign_config': [array([ True, False, False, False, False,  True, False, False, False,\n",
       "         False]),\n",
       "  array([False,  True, False,  True, False,  True, False, False,  True,\n",
       "         False])],\n",
       " 'best_x': array([0.41935957, 0.25034228]),\n",
       " 'c_vector': tensor([ 1., -1.]),\n",
       " 'value': 0.06319125435726702,\n",
       " 'compute_time': 0.041378021240234375,\n",
       " 'domain': <hyperbox.Hyperbox at 0x7f5f068eb278>}"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "lipresult.as_dict()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "with open('wtf.pkl', 'wb') as g:\n",
    "    pickle.dump(lipresult, g)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "with open('wtf.pkl', 'rb') as h:\n",
    "    output = (pickle.load(h))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[array([ True, False, False, False, False,  True, False, False, False,\n",
       "        False]),\n",
       " array([False,  True, False,  True, False,  True, False, False,  True,\n",
       "        False])]"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "output.sign_config"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pickle.load"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with open('test_job2.result', 'rb') as f:\n",
    "    pickle.load(f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(output['do_random_evals'].average_stdevs('value'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!ls *.result"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
