{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from rcwc import *\n",
    "import numpy as np\n",
    "import cvxpy as cp\n",
    "from sklearn.neighbors import NearestNeighbors as kNN\n",
    "from matplotlib import pyplot as plt\n",
    "import os, time"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from matplotlib import cm\n",
    "import matplotlib\n",
    "from matplotlib.ticker import ScalarFormatter\n",
    "import seaborn as sns"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Seed NumPy's global random state so every random draw below is reproducible\n",
    "SEED = 94\n",
    "np.random.seed(SEED)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Population: n points drawn uniformly at random from the unit square\n",
    "n = 50  # number of points\n",
    "pts = np.random.rand(n, 2)  # row j holds the two coordinates of point j"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Building spatially correlated dataset: the value of each point is the sum\n",
    "# of its two coordinates minus 1, so it lies in [-1, 1) and nearby points\n",
    "# have similar values\n",
    "x = pts[:,0] + pts[:,1] - 1\n",
    "colormap = matplotlib.colors.ListedColormap(sns.color_palette(\"coolwarm\",8))\n",
    "plt.scatter(pts[:,0], pts[:,1], c=x, cmap=colormap)\n",
    "plt.clim(-1,1)\n",
    "plt.colorbar()\n",
    "plt.xlabel('x')\n",
    "# rotation has to be a number (or 'horizontal' / 'vertical'); the string '0'\n",
    "# is rejected with a ValueError by newer matplotlib releases\n",
    "plt.ylabel('y', rotation=0, labelpad=20)\n",
    "plt.yticks([0,0.5,1])\n",
    "plt.tight_layout()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Creating edges\n",
    "# Six neighbours are requested because the query points are the fitted points\n",
    "# themselves: the first column is dropped below as the self edge, which\n",
    "# leaves five outgoing edges per point\n",
    "knearest = kNN(n_neighbors=6)\n",
    "knearest.fit(pts)\n",
    "dists, neighbors = knearest.kneighbors(pts)\n",
    "neighbors = neighbors[:, 1:]  # get rid of self edges"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Generating samples: m recruitment waves grown from one shared root node\n",
    "m = 1000 # num samples\n",
    "k = 25 # size of sample\n",
    "nrecruits = 2 # number of new recruits\n",
    "A = np.zeros((m,n))  # A[i, j] is set to 1 when node j belongs to sample i\n",
    "b = np.ones((m,n)) / n  # every entry equals 1/n\n",
    "start_node = np.random.randint(low=0,high=n)\n",
    "A[:,start_node] = 1  # the root is part of every sample\n",
    "\n",
    "# NOTE(review): np.random.choice draws with replacement by default and\n",
    "# neighbours already in the sample are not excluded, so a recruiter can add\n",
    "# fewer than nrecruits new nodes - confirm this is intended.\n",
    "# NOTE(review): every recruiter draws at least one neighbour, so `recruits`\n",
    "# is never empty and the while loop only stops once the sample holds k nodes;\n",
    "# it would not terminate if fewer than k nodes can be reached.\n",
    "for i in range(m):\n",
    "  sample = {start_node}\n",
    "  recruiters = {start_node}\n",
    "  while recruiters and len(sample) < k:\n",
    "    recruits = set()\n",
    "    for r in recruiters:\n",
    "      possible_recruits = tuple(neighbors[r])\n",
    "      # never draw more than the number of free slots left in the sample\n",
    "      n_draws = min(k - len(sample), nrecruits, len(possible_recruits))\n",
    "      idx = np.random.choice(len(possible_recruits), size=n_draws)\n",
    "      recruits.update(possible_recruits[j] for j in idx)\n",
    "      sample.update(recruits)\n",
    "      if len(sample) >= k:\n",
    "        break\n",
    "    # only the nodes recruited in this wave recruit in the next one\n",
    "    recruiters = recruits\n",
    "  A[i, list(sample)] = 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot a random sample (gold is root, red is in sample, all points are in target)\n",
    "i = np.random.randint(0,m)\n",
    "in_sample = A[i] == 1  # boolean mask over the nodes of the chosen sample\n",
    "plt.scatter(pts[:,0], pts[:,1])\n",
    "plt.scatter(pts[in_sample,0], pts[in_sample,1], c='red')\n",
    "plt.scatter(pts[start_node,0], pts[start_node,1],c='gold')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot individual probability of being sampled\n",
    "# (fraction of the m samples that contain each node)\n",
    "sample_prob = A.sum(axis=0) / m\n",
    "mask = sample_prob > 0\n",
    "xs, ys = pts[:,0], pts[:,1]\n",
    "plt.scatter(xs, ys, c='grey')  # nodes that were never sampled stay grey\n",
    "plt.scatter(xs[mask], ys[mask], c=sample_prob[mask])\n",
    "plt.colorbar()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Basic estimator (sample mean): each row of A is scaled to sum to 1, so\n",
    "# every member of a sample gets the same weight\n",
    "a_mean = A / A.sum(axis=1, keepdims=True)\n",
    "# each metric is computed and printed in turn so that any output produced by\n",
    "# the evaluators keeps its place relative to the printed lines\n",
    "mean_worst_case = evaluate_weights_grothendieck(a_mean, b)\n",
    "print('Worst-case error: {}'.format(mean_worst_case))\n",
    "mean_spatial = evaluate_weights(a_mean, b, x)\n",
    "print('Spatial values error: {}'.format(mean_spatial))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "# RCWC estimator, timed with the cell magic above (which has to stay on the\n",
    "# first line of the cell)\n",
    "a_rcwc = rcwc(A, b)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Evaluating RCWC results with the same two metrics used for the sample mean\n",
    "rcwc_worst_case = evaluate_weights_grothendieck(a_rcwc, b, is_verbose=False)\n",
    "print('Worst-case error: {}'.format(rcwc_worst_case))\n",
    "rcwc_spatial = evaluate_weights(a_rcwc, b, x)\n",
    "print('Spatial values error: {}'.format(rcwc_spatial))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
