{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from rcwc import *\n",
    "import numpy as np\n",
    "import cvxpy as cp\n",
    "from sklearn.neighbors import NearestNeighbors as kNN\n",
    "from matplotlib import pyplot as plt\n",
    "import os, time"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from matplotlib import cm\n",
    "import matplotlib\n",
    "from matplotlib.ticker import ScalarFormatter\n",
    "import seaborn as sns"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Seed NumPy's global legacy RNG so the np.random draws in later cells repeat on a fresh run.\n",
    "np.random.seed(seed=94)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# RCWC estimator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Generating samples\n",
    "n = 50  # size of samples\n",
    "m = 7500  # num samples\n",
    "\n",
    "# Per-coordinate inclusion probability: 0.1 for coordinates 0-24, 0.5 for coordinates 25-49.\n",
    "p = np.ones(n)\n",
    "p[0:25] = 0.1\n",
    "p[25:50] = 0.5\n",
    "\n",
    "# A[i, j] is 1.0 when the uniform draw for (i, j) falls below p[j], else 0.0.\n",
    "# A single (m, n) draw fills row by row from the global RNG stream, so it yields\n",
    "# the same matrix as drawing each of the m rows with its own size-n call.\n",
    "A = (np.random.random(size=(m, n)) < p).astype(float)\n",
    "b_rcwc = np.full((m, n), 1 / n)  # target = population"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "# RCWC estimator: solve for the weights a_rcwc from the sample matrix A and target b_rcwc.\n",
    "solveroptions = dict()  # empty options dict, passed through to rcwc unchanged\n",
    "a_rcwc = rcwc(\n",
    "    A,\n",
    "    b_rcwc,\n",
    "    solveroptions=solveroptions,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Report the Grothendieck-based evaluation of the RCWC weights (verbose output disabled).\n",
    "worst_case_rcwc = evaluate_weights_grothendieck(a_rcwc, b_rcwc, is_verbose=False)\n",
    "print(f'Worst-case error: {worst_case_rcwc}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Importance sampling estimators"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Generating samples\n",
    "n = 50  # size of samples\n",
    "m = 1000000  # num samples\n",
    "\n",
    "# Per-coordinate inclusion probability: 0.1 for coordinates 0-24, 0.5 for coordinates 25-49.\n",
    "p = np.ones(n)\n",
    "p[0:25] = 0.1\n",
    "p[25:50] = 0.5\n",
    "\n",
    "# A[i, j] is 1.0 when the uniform draw for (i, j) falls below p[j], else 0.0.\n",
    "# A single (m, n) draw fills row by row from the global RNG stream, so it yields\n",
    "# the same matrix as drawing each of the m rows with its own size-n call, without\n",
    "# a million-iteration Python loop. b is allocated afterwards so the temporary\n",
    "# uniform draw is released before the second (m, n) array is created.\n",
    "A = (np.random.random(size=(m, n)) < p).astype(float)\n",
    "b = np.full((m, n), 1 / n)  # target = population"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Basic estimator (sample mean): every sampled coordinate in a row gets equal weight.\n",
    "row_counts = A.sum(axis=1, keepdims=True)\n",
    "# A row with no sampled coordinate is all zeros; dividing it by 1 keeps its weights at 0.\n",
    "safe_counts = np.where(row_counts == 0, 1, row_counts)\n",
    "a_mean = A / safe_counts"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Importance sampling: weight each sampled coordinate by 1/n times the inverse of its inclusion probability.\n",
    "inv_prob = 1 / p\n",
    "a_imp = (A / n) * inv_prob\n",
    "print(a_imp[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Importance sampling v2\n",
    "# Each half of the coordinates (0-24 and 25-49) receives total weight 1/2, split\n",
    "# evenly over the coordinates that were sampled in that half.\n",
    "a_imp2 = np.zeros((m, n))\n",
    "for stratum in (slice(0, 25), slice(25, 50)):\n",
    "    counts = np.sum(A[:, stratum], axis=1, keepdims=True)\n",
    "    # A half with no sampled coordinate has an all-zero numerator; dividing by 1\n",
    "    # keeps its weights at 0 instead of producing 0/0 = NaN. The guard is applied\n",
    "    # to both halves so neither can leak NaN into a_imp2.\n",
    "    counts[counts == 0] = 1\n",
    "    a_imp2[:, stratum] = A[:, stratum] / counts / 2\n",
    "print(a_imp2[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fixed datasets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Three fixed +/-1 vectors of length n, built from coordinate masks.\n",
    "coord = np.arange(n)\n",
    "# x1: all +1.\n",
    "x1 = np.ones(n)\n",
    "# x2: -1 on coordinates 0-24, +1 elsewhere.\n",
    "x2 = np.where(coord < 25, -1.0, 1.0)\n",
    "# x3: -1 on coordinates 0-11 and 25-37, +1 elsewhere.\n",
    "x3 = np.where((coord < 12) | ((coord >= 25) & (coord < 38)), -1.0, 1.0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sample-mean weights: evaluate against each fixed dataset, then run the Grothendieck evaluation.\n",
    "for x_fixed in (x1, x2, x3):\n",
    "    print(evaluate_weights(a_mean, b, x_fixed))\n",
    "print(evaluate_weights_grothendieck(a_mean, b))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Importance-sampling weights: evaluate against each fixed dataset, then run the Grothendieck evaluation.\n",
    "for x_fixed in (x1, x2, x3):\n",
    "    print(evaluate_weights(a_imp, b, x_fixed))\n",
    "print(evaluate_weights_grothendieck(a_imp, b))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Importance-sampling v2 weights: evaluate against each fixed dataset, then run the Grothendieck evaluation.\n",
    "for x_fixed in (x1, x2, x3):\n",
    "    print(evaluate_weights(a_imp2, b, x_fixed))\n",
    "print(evaluate_weights_grothendieck(a_imp2, b))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# RCWC weights (paired with their own target b_rcwc): evaluate against each fixed dataset,\n",
    "# then run the Grothendieck evaluation.\n",
    "for x_fixed in (x1, x2, x3):\n",
    "    print(evaluate_weights(a_rcwc, b_rcwc, x_fixed))\n",
    "print(evaluate_weights_grothendieck(a_rcwc, b_rcwc))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
