{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from rcwc import *\n",
    "import numpy as np\n",
    "import cvxpy as cp\n",
    "from matplotlib import pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def build_data(k):\n",
    "  '''\n",
    "  Construct the selective prediction instance for sequences of length n = 2^k.\n",
    "\n",
    "  One row is created for every window width w = 2^i (i = 0, ..., k-1) and\n",
    "  every split point t = w, ..., n-w. For that row:\n",
    "    * A has ones in positions 0, ..., t-1 and zeros from t onwards,\n",
    "    * b is uniform over the w positions t, ..., t+w-1,\n",
    "    * a_sp is uniform over the w positions t-w, ..., t-1,\n",
    "    * weights holds 1 / (k * number of rows that share this width).\n",
    "\n",
    "  Returns the tuple (A, b, weights, a_sp). The three matrices have shape\n",
    "  (m, n) and weights has length m, where m = n*(k-2) + k + 2.\n",
    "  '''\n",
    "  num_cols = 2**k\n",
    "  num_rows = num_cols*(k-2) + k + 2\n",
    "  shape = (num_rows, num_cols)\n",
    "  A, b, a_sp = np.ones(shape), np.zeros(shape), np.zeros(shape)\n",
    "  weights = np.zeros(num_rows)\n",
    "\n",
    "  row = 0\n",
    "  for scale in range(k):\n",
    "    width = 2**scale\n",
    "    # All rows of one scale share the same weight.\n",
    "    scale_weight = 1 / (k*(num_cols + 1 - 2**(scale+1)))\n",
    "    for split in range(width, num_cols - width + 1):\n",
    "      A[row, split:] = 0\n",
    "      a_sp[row, split-width:split] = 1\n",
    "      b[row, split:split+width] = 1\n",
    "      weights[row] = scale_weight\n",
    "      row += 1\n",
    "\n",
    "  # Rescale the 0/1 indicator rows so that each row sums to one.\n",
    "  a_sp /= a_sp.sum(axis=1, keepdims=True)\n",
    "  b /= b.sum(axis=1, keepdims=True)\n",
    "\n",
    "  return A, b, weights, a_sp"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# Instances are built for n = 2^k with k = 1, ..., max_k.\n",
    "max_k = 6\n",
    "\n",
    "# One entry per k: what each method produced, and the value that\n",
    "# evaluate_weights_grothendieck reports for it.\n",
    "weights_rcwc, weights_sp = [], []\n",
    "rcwc_errors, sp_errors = np.zeros(max_k), np.zeros(max_k)\n",
    "\n",
    "for slot, k in enumerate(range(1, max_k+1)):\n",
    "  # Blank line followed by the current k, as a progress marker.\n",
    "  print(f'\\n{k}')\n",
    "  A, b, weights, a_sp = build_data(k)\n",
    "\n",
    "  # rcwc and evaluate_weights_grothendieck are not defined in this notebook;\n",
    "  # they are expected to come from the star import of the rcwc module.\n",
    "  a_rcwc = rcwc(A, b, weights=weights, is_verbose=True)\n",
    "  weights_rcwc.append(a_rcwc)\n",
    "  weights_sp.append(a_sp)\n",
    "\n",
    "  # Evaluate both candidates on the same (b, weights); slot equals k-1.\n",
    "  rcwc_errors[slot] = evaluate_weights_grothendieck(a_rcwc, b, weights=weights)\n",
    "  sp_errors[slot] = evaluate_weights_grothendieck(a_sp, b, weights=weights)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plotting results\n",
    "from matplotlib import cm\n",
    "import matplotlib\n",
    "from matplotlib.ticker import ScalarFormatter\n",
    "import seaborn as sns"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot the error of both methods against n on log-log axes.\n",
    "# Note: this font size change is global and affects every later figure too.\n",
    "matplotlib.rcParams.update({'font.size': 30})\n",
    "\n",
    "# x-values n = 2^k for k = 2, ..., max_k; the k = 1 entry of each error\n",
    "# array is skipped via the [1:] slices below.\n",
    "ns = 2 ** np.arange(2, max_k+1)\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(8,7))\n",
    "ax.plot(ns, sp_errors[1:], color='blue', linewidth=4, marker='o', markersize=10, alpha=0.8)\n",
    "ax.plot(ns, rcwc_errors[1:], color='green', linewidth=4, marker='D', markersize=10, alpha=0.8)\n",
    "ax.set_yscale('log')\n",
    "\n",
    "# The log-scale keyword was renamed from basex to base in Matplotlib 3.3 and\n",
    "# the old spelling was removed in 3.5, so try the new name first and fall\n",
    "# back to the old one on releases that do not know it.\n",
    "try:\n",
    "  ax.set_xscale('log', base=2)\n",
    "except (TypeError, ValueError):\n",
    "  ax.set_xscale('log', basex=2)\n",
    "\n",
    "ax.yaxis.set_major_formatter(matplotlib.ticker.FormatStrFormatter('%.2f'))\n",
    "ax.set_yticks([0.30, 1.00, 3.00])\n",
    "ax.set_ylim([0.15,3.5])\n",
    "\n",
    "# Show plain integers (4, 8, ...) instead of powers of two, one tick per\n",
    "# plotted n so the ticks follow max_k.\n",
    "ax.xaxis.set_major_formatter(ScalarFormatter())\n",
    "ax.set_xticks(ns)\n",
    "\n",
    "# Legend labels follow the order in which the two curves were plotted.\n",
    "ax.legend(('Selective Prediction', 'RCWC'), frameon=False, loc=(0,0))\n",
    "ax.set_ylabel('Prediction Error (Log Scale)')\n",
    "ax.set_xlabel('n (Log Scale)')\n",
    "fig.tight_layout()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
