{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from matplotlib import pyplot as plt\n",
    "\n",
    "# `sklearn.externals.joblib` was deprecated in scikit-learn 0.21 and removed in 0.23.\n",
    "# Prefer the standalone `joblib` package (a scikit-learn dependency) and fall back\n",
    "# to the vendored copy only on old installs.\n",
    "try:\n",
    "    import joblib\n",
    "except ImportError:\n",
    "    from sklearn.externals import joblib"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Experiment selection; results are read from the './<target>_<model>' directory.\n",
    "# target options used with this notebook: 'adult', '20news', 'mnist'\n",
    "# model options used with this notebook:  'logreg', 'dnn'\n",
    "target, model = 'adult', 'logreg'\n",
    "seed = 0  # run index; appears as the zero-padded 3-digit suffix of the result files"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the saved per-sample loss differences for this run: the ground truth,\n",
    "# the proposed method's estimate, and the K&L (ICML) baseline estimate.\n",
    "# NOTE: joblib.load unpickles the file, so only load results you produced yourself.\n",
    "dn = f'./{target}_{model}'\n",
    "infl_true = joblib.load(f'{dn}/loss_diff_true_{seed:03d}.dat')\n",
    "infl_sgd = joblib.load(f'{dn}/loss_diff_proposed_{seed:03d}.dat')\n",
    "if model == 'dnn':\n",
    "    # The DNN result has a second axis; keep only its last entry\n",
    "    # (presumably the estimate at the final step - TODO confirm).\n",
    "    infl_sgd = infl_sgd[:, -1]\n",
    "infl_icml = joblib.load(f'{dn}/loss_diff_icml_{seed:03d}.dat')\n",
    "\n",
    "# Each estimate is plotted against infl_true, so all three must cover the same samples.\n",
    "assert len(infl_true) == len(infl_sgd) == len(infl_icml), (\n",
    "    'length mismatch: true=%d, proposed=%d, icml=%d'\n",
    "    % (len(infl_true), len(infl_sgd), len(infl_icml)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Scatter each estimator against the ground truth; points on the dashed\n",
    "# y = x line are estimates that exactly match the true influence.\n",
    "fig, ax = plt.subplots()\n",
    "ax.plot(infl_true, infl_icml, 'rs', ms=6, label='K&L')\n",
    "ax.plot(infl_true, infl_sgd, 'bo', ms=6, label='Proposed')\n",
    "\n",
    "# Use one shared range for both axes and pin it, so the y = x reference\n",
    "# runs corner to corner of a square plot (autoscale margins would shift it).\n",
    "xlim, ylim = ax.get_xlim(), ax.get_ylim()\n",
    "lim = [min(xlim[0], ylim[0]), max(xlim[1], ylim[1])]\n",
    "ax.plot(lim, lim, 'k--')\n",
    "ax.set_xlim(lim)\n",
    "ax.set_ylim(lim)\n",
    "ax.set_aspect('equal', adjustable='box')\n",
    "\n",
    "ax.set_xlabel('True Linear Influence')\n",
    "ax.set_ylabel('Estimated Influence')\n",
    "ax.set_title(f'{target} / {model} (seed {seed})')\n",
    "ax.legend()  # the unlabeled y = x line is excluded automatically\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
