{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "8a3768d3",
   "metadata": {
    "heading_collapsed": true
   },
   "source": [
    "# Code"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "68a3422e",
   "metadata": {
    "hidden": true
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import time\n",
    "from numba import jit\n",
    "from scipy.spatial import distance_matrix\n",
    "import os\n",
    "import h5py\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "1196046f",
   "metadata": {
    "hidden": true
   },
   "outputs": [],
   "source": [
    "# Algorithm for L1 matrix vector query\n",
    "\n",
    "def preprocess(X):\n",
    "    return np.argsort(X, axis=0), np.argsort(np.argsort(X, axis=0), axis=0)\n",
    "\n",
    "@jit(nopython=True)\n",
    "def inner_loop(X,order,B,C,n,d):\n",
    "    z = np.zeros(n)\n",
    "    for k in range(n):\n",
    "        for i in range(d):\n",
    "            q = order[k, i]\n",
    "            S1 = B[q,i]\n",
    "            S2 = B[n-1,i] - B[q,i]\n",
    "            S3 = C[q,i]\n",
    "            S4 = C[n-1,i] - C[q,i]\n",
    "            z[k] += X[k,i]*(S3-S4) + S2-S1\n",
    "    return z\n",
    "    \n",
    "\n",
    "def query(X, order1, order2, y):\n",
    "    n,d = X.shape\n",
    "    B = np.take_along_axis((((X.T)*y).T), order1, axis=0).cumsum(axis=0)\n",
    "    C = (y[order1.T].T).cumsum(axis=0)\n",
    "    return inner_loop(X,order2,B,C,n,d)\n",
    "\n",
    "@jit(nopython=True)\n",
    "def naive(X, y):\n",
    "    n,d = X.shape\n",
    "    z = np.zeros(n)\n",
    "    for i in range(n):\n",
    "        for j in range(n):\n",
    "            z[i] += (np.abs(X[i,:] - X[j,:]).sum())*y[j]\n",
    "    return z\n",
    "            "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b7876f3e",
   "metadata": {
    "heading_collapsed": true
   },
   "source": [
    "# MNIST"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 52,
   "id": "2a9054bd",
   "metadata": {
    "hidden": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2672.535535097122\n"
     ]
    }
   ],
   "source": [
    "# Need to first load the standard MNIST dataset\n",
    "# Time to create MNIST distance matrix\n",
    "start = time.time()\n",
    "dist_matrix_mnist = distance_matrix(train_images[:50000,:],train_images[:50000,:],p=1).astype(np.float32)\n",
    "print(time.time() - start)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 53,
   "id": "946b6717",
   "metadata": {
    "hidden": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "5.548295259475708\n"
     ]
    }
   ],
   "source": [
    "# Our preprocessing times\n",
    "start = time.time()\n",
    "order1, order2 = preprocess(train_images[:50000,:])\n",
    "print(time.time()-start)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 57,
   "id": "df2ca681",
   "metadata": {
    "hidden": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[40.88059997558594, 38.402283906936646, 38.69560098648071, 37.60749077796936, 37.618812799453735]\n"
     ]
    }
   ],
   "source": [
    "# Query time using the distance matrix\n",
    "total_time = []\n",
    "for i in range(5):\n",
    "    y = np.random.random(50000)\n",
    "    start = time.time()\n",
    "    output1 = dist_matrix_mnist.dot(y)\n",
    "    total_time.append(time.time()-start)\n",
    "print(total_time)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 60,
   "id": "3a99341d",
   "metadata": {
    "hidden": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[1.8774049282073975, 1.834322214126587, 1.8389580249786377, 1.855522871017456, 1.8549790382385254, 1.8368868827819824, 1.8431248664855957, 1.848099946975708, 1.8464722633361816, 1.8517050743103027]\n"
     ]
    }
   ],
   "source": [
    "# Query time of our algorithm\n",
    "total_time2 = []\n",
    "for i in range(10):\n",
    "    y = np.random.random(50000)\n",
    "    start = time.time()\n",
    "    output2 = query(train_images[:50000,:],order1, order2, y)\n",
    "    total_time2.append(time.time()-start)\n",
    "print(total_time2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 61,
   "id": "060e3e56",
   "metadata": {
    "hidden": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "38.64095768928528\n",
      "1.8487476110458374\n"
     ]
    }
   ],
   "source": [
    "print(np.mean([40.88059997558594, 38.402283906936646, 38.69560098648071, 37.60749077796936, 37.618812799453735]))\n",
    "print(np.mean([1.8774049282073975, 1.834322214126587, 1.8389580249786377, 1.855522871017456, 1.8549790382385254, 1.8368868827819824, 1.8431248664855957, 1.848099946975708, 1.8464722633361816, 1.8517050743103027]))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e59ce8c9",
   "metadata": {
    "heading_collapsed": true
   },
   "source": [
    "# Glove Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 93,
   "id": "31997ed0",
   "metadata": {
    "hidden": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Keys: <KeysViewHDF5 ['distances', 'neighbors', 'test', 'train']>\n",
      "(1183514, 50)\n"
     ]
    }
   ],
   "source": [
    "filename = \"glove-50-angular.hdf5\"\n",
    "# Need to download dataset from https://github.com/erikbern/ann-benchmarks/\n",
    "\n",
    "\n",
    "with h5py.File(filename, \"r\") as f:\n",
    "    points = np.array(f['train']) \n",
    "\n",
    "print(points.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "e3dd61f0",
   "metadata": {
    "hidden": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "398.64956879615784\n"
     ]
    }
   ],
   "source": [
    "# Naive runtime, only works for a small input!\n",
    "\n",
    "start = time.time()\n",
    "y = np.random.random(50000)\n",
    "output = naive(points[:50000,:], y)\n",
    "print(time.time()-start)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "a7ae7856",
   "metadata": {
    "hidden": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "2.59389886702963"
      ]
     },
     "execution_count": 25,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Estimate how long naive will take on whole dataset in days\n",
    "(400*(points.shape[0]/50000)**2)/(60*60*24)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "e2b306ed",
   "metadata": {
    "hidden": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "16.739588022232056\n"
     ]
    }
   ],
   "source": [
    "# Preprocessing for our L1 algorithm\n",
    "start = time.time()\n",
    "order1, order2 = preprocess(points)\n",
    "print(time.time()-start)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "id": "d4412cf8",
   "metadata": {
    "hidden": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "3.3747401237487793\n"
     ]
    }
   ],
   "source": [
    "# Print our query time. Note, might need to run this multiple times or first run it on a small sample to properly initialize Numba\n",
    "y = np.random.random(points.shape[0])\n",
    "start = time.time()\n",
    "output2 = query(points,order1, order2, y)\n",
    "print(time.time()-start)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "29a9a90b",
   "metadata": {
    "heading_collapsed": true
   },
   "source": [
    "# Gaussian Mixture"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 94,
   "id": "4ec018ae",
   "metadata": {
    "hidden": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(51000, 50)\n"
     ]
    }
   ],
   "source": [
    "# Create dataset\n",
    "points = np.random.normal(size = (17000,50))\n",
    "points = np.vstack((points, np.random.normal(size = (17000,50)) + 5))\n",
    "points = np.vstack((points, np.random.normal(size = (17000,50)) - 5))\n",
    "print(points.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 95,
   "id": "3af7b4e1",
   "metadata": {
    "hidden": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "453.6680850982666\n"
     ]
    }
   ],
   "source": [
    "# Create distance matrix\n",
    "start = time.time()\n",
    "dist_matrix_mixture = distance_matrix(points, points,p=1).astype(np.float32)\n",
    "print(time.time()-start)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 96,
   "id": "b9c6623a",
   "metadata": {
    "hidden": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.5509281158447266\n"
     ]
    }
   ],
   "source": [
    "# Preprocess using our algorithm\n",
    "start = time.time()\n",
    "order1, order2 = preprocess(points)\n",
    "print(time.time()-start)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 98,
   "id": "bc40894a",
   "metadata": {
    "hidden": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[45.22181415557861, 40.10519814491272, 45.72833299636841, 40.60417294502258, 44.847196102142334]\n",
      "43.30134286880493\n",
      "[0.15776586532592773, 0.07967019081115723, 0.08495116233825684, 0.07862401008605957, 0.07572197914123535, 0.07488203048706055, 0.07515716552734375, 0.07703185081481934, 0.07836174964904785, 0.076934814453125]\n",
      "0.08591008186340332\n"
     ]
    }
   ],
   "source": [
    "# Compare query times of using the distance matrix vs using our algorithm\n",
    "total_time = []\n",
    "for i in range(5):\n",
    "    y = np.random.random(51000)\n",
    "    start = time.time()\n",
    "    output1 = dist_matrix_mixture.dot(y)\n",
    "    total_time.append(time.time()-start)\n",
    "print(total_time)\n",
    "print(np.mean(total_time))\n",
    "\n",
    "\n",
    "total_time2 = []\n",
    "for i in range(10):\n",
    "    y = np.random.random(51000)\n",
    "    start = time.time()\n",
    "    output2 = query(points,order1, order2, y)\n",
    "    total_time2.append(time.time()-start)\n",
    "print(total_time2)\n",
    "print(np.mean(total_time2))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
