{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from matplotlib import pyplot as plt\n",
    "from scipy import stats\n",
    "\n",
    "try:\n",
    "    from sklearn.externals import joblib\n",
    "except ImportError:\n",
    "    # sklearn.externals.joblib was removed in scikit-learn 0.23; use the standalone package.\n",
    "    import joblib"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['icml' 'random' 'sgd_all' 'sgd_last'] ['ae' 'iso'] [    1     3     6    10    30    60   100   300   600  1000  3000  6000\n",
      " 10000]\n"
     ]
    }
   ],
   "source": [
    "# Experiment selection. The other dataset option is 'cifar10'; start_epoch = 0 is the\n",
    "# other epoch setting that has been used with this notebook.\n",
    "target = 'mnist'\n",
    "start_epoch, end_epoch = 19, 20\n",
    "seed = 0\n",
    "\n",
    "# Load the results of a single seed only to discover which methods and k values exist.\n",
    "path_fields = (target, target, seed, start_epoch, end_epoch)\n",
    "res = joblib.load('./%s/%s_%02d/eval_epoch_%02d_to_%02d.dat' % path_fields)\n",
    "res2 = joblib.load('./%s/%s_%02d/eval_epoch_%02d_to_%02d_outlier.dat' % path_fields)\n",
    "\n",
    "# 'baseline' is a plain string key, unlike the (method, k) tuples read below, so remove it\n",
    "# before the keys are inspected.\n",
    "del res['baseline']\n",
    "del res2['baseline']\n",
    "\n",
    "# np.unique already returns its result sorted.\n",
    "methods = np.unique([key[0] for key in res])\n",
    "methods2 = np.unique([key[0] for key in res2])\n",
    "ks = np.unique([key[1] for key in res])\n",
    "print(methods, methods2, ks)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(13, 7, 30)\n"
     ]
    }
   ],
   "source": [
    "n = 30  # number of seeds (runs) to aggregate\n",
    "t = 5   # position read from every stored entry; NOTE(review): its meaning is not visible here - confirm\n",
    "\n",
    "# acc[j, c, s] holds the value for ks[j], column c and seed s. Column 0 is the baseline,\n",
    "# followed by one column per entry of methods and then one per entry of methods2.\n",
    "acc = np.zeros((ks.size, methods.size + methods2.size + 1, n))\n",
    "\n",
    "# Explicit record of the seeds that were read completely. This replaces the former test\n",
    "# for non-zero entries in acc, which would also have dropped a seed holding a genuine 0.0.\n",
    "loaded = np.zeros(n, dtype=bool)\n",
    "\n",
    "for seed in range(n):\n",
    "    try:\n",
    "        res = joblib.load('./%s/%s_%02d/eval_epoch_%02d_to_%02d.dat' % (target, target, seed, start_epoch, end_epoch))\n",
    "        res2 = joblib.load('./%s/%s_%02d/eval_epoch_%02d_to_%02d_outlier.dat' % (target, target, seed, start_epoch, end_epoch))\n",
    "        # The single baseline value is broadcast over every k.\n",
    "        acc[:, 0, seed] = res.pop('baseline')[t]\n",
    "        for i, m in enumerate(methods):\n",
    "            for j, k in enumerate(ks):\n",
    "                acc[j, i + 1, seed] = res[(m, k)][t]\n",
    "        for i, m in enumerate(methods2):\n",
    "            for j, k in enumerate(ks):\n",
    "                acc[j, methods.size + i + 1, seed] = res2[(m, k)][t]\n",
    "    except Exception as err:\n",
    "        # Still best effort: a seed with missing or incomplete results is skipped, but it is\n",
    "        # reported, and KeyboardInterrupt / SystemExit are no longer swallowed.\n",
    "        print('skipping seed %02d: %s: %s' % (seed, type(err).__name__, err))\n",
    "        continue\n",
    "    loaded[seed] = True\n",
    "\n",
    "# Keep only the seeds that loaded completely.\n",
    "idx = np.where(loaded)[0]\n",
    "acc = acc[:, :, idx]\n",
    "print(acc.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-dataset y-axis window; any other target keeps matplotlib's automatic limits.\n",
    "ylims = {'mnist': (0.007, 0.011), 'cifar10': (0.155, 0.175)}\n",
    "\n",
    "fig, ax = plt.subplots()\n",
    "# One curve per column of acc: 1 - mean over the seed axis (axis 2), against k on a log x-axis.\n",
    "ax.semilogx(ks, 1 - np.mean(acc, axis=2))\n",
    "ax.legend(['baseline', *methods.tolist(), *methods2.tolist()])\n",
    "ax.set_xlabel('# of instances removed')\n",
    "ax.set_ylabel('misclassification rate')\n",
    "if target in ylims:\n",
    "    ax.set_ylim(ylims[target])\n",
    "ax.set_title('%s' % (target,))\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
