{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Compute the exact Shapley values of the max function in $O(M^2)$ complexity\n",
    "\n",
    "In general, computing exact Shapley values takes $O(2^M)$ time per feature, where $M$ is the number of input features. Here we show that for the max function all $M$ values can be computed together in $O(M^2)$ time. We assume a single reference input rather than a whole dataset, so to compute expectations over a dataset the procedure must be repeated for each sampled reference.\n",
    "\n",
    "Below is the algorithm (in Julia code) for an input vector $x$ and a reference vector $r$. It treats the max function as a decision tree and weights each possible binary outcome by the number of permutations that match it. Note that once a feature is fixed to a value that is greater than all the other values, the decision tree stops branching. By sorting the features by the maximum of their input and reference values, we ensure that one branch of the decision tree stops branching at each step. Once at a leaf of the tree, we compute the effects for all the features encountered so far."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "shapley_max (generic function with 1 method)"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Exact Shapley values of `maximum` for the input `x` against a single reference `r`.\n",
    "#\n",
    "# Arguments:\n",
    "#   x - vector of input feature values\n",
    "#   r - vector of reference feature values, same length as x\n",
    "#\n",
    "# Returns a Float64 vector phi of length(x); phi[k] is the attribution of feature k\n",
    "# (in the caller's original feature order).\n",
    "# Throws DimensionMismatch when x and r differ in length.\n",
    "#\n",
    "# After the initial sort, the work is the double loop over (i, j <= i) below.\n",
    "function shapley_max(x, r)\n",
    "    # zip() silently truncates to the shorter argument, so reject mismatched\n",
    "    # lengths up front instead of ignoring trailing entries.\n",
    "    if length(x) != length(r)\n",
    "        throw(DimensionMismatch(\"x and r must have the same length\"))\n",
    "    end\n",
    "\n",
    "    # sort so that at each step we know either the input value\n",
    "    # or the reference value of the next feature will be the next largest value\n",
    "    perm = sortperm(collect(zip(x,r)), by=maximum, rev=true)\n",
    "    xsorted = x[perm]\n",
    "    rsorted = r[perm]\n",
    "\n",
    "    M = length(x)\n",
    "    path = zeros(M)       # 1 = sorted feature takes its input value on the current branch, 0 = reference value\n",
    "    weight = 1.0          # running product of weight_scale along the branch that keeps going\n",
    "    num_ones = 0          # number of ones in path[1:i-1]\n",
    "    phi = zeros(M)        # accumulated attributions, written through perm\n",
    "    last_val = -Inf       # largest min(x, r) among the features already passed\n",
    "    weight_scale = 1.0\n",
    "\n",
    "    for i in 1:M\n",
    "        # because of the descending sort, no later feature can exceed this value\n",
    "        largest_remaining = i == M ? -Inf : max(xsorted[i+1], rsorted[i+1])\n",
    "\n",
    "        # Leaf reached by giving feature i its input value.\n",
    "        if xsorted[i] >= largest_remaining\n",
    "            path[i] = 1\n",
    "            # part of the leaf's contribution that does not depend on j\n",
    "            leaf = max(last_val, xsorted[i])*weight*((num_ones+1)/i)\n",
    "            for j in 1:i\n",
    "                if path[j] == 1\n",
    "                    phi[perm[j]] += leaf/(num_ones+1)\n",
    "                else\n",
    "                    # only reached when path[1:i] contains a zero, so the divisor is nonzero\n",
    "                    phi[perm[j]] -= leaf/(i-num_ones-1)\n",
    "                end\n",
    "            end\n",
    "            # the branch that keeps going has feature i at its reference value\n",
    "            path[i] = 0\n",
    "            weight_scale = (i-num_ones)/i\n",
    "        end\n",
    "\n",
    "        # Leaf reached by giving feature i its reference value.\n",
    "        if rsorted[i] >= largest_remaining\n",
    "            path[i] = 0\n",
    "            leaf = max(last_val, rsorted[i])*weight*((i-num_ones)/i)\n",
    "            for j in 1:i\n",
    "                if path[j] == 1\n",
    "                    # only reached when num_ones > 0, so the divisor is nonzero\n",
    "                    phi[perm[j]] += leaf/num_ones\n",
    "                else\n",
    "                    phi[perm[j]] -= leaf/(i-num_ones)\n",
    "                end\n",
    "            end\n",
    "            # the branch that keeps going has feature i at its input value\n",
    "            path[i] = 1\n",
    "            num_ones += 1\n",
    "            weight_scale = num_ones/i\n",
    "        end\n",
    "\n",
    "        # both settings of feature i are leaves, so nothing is left to branch on\n",
    "        if xsorted[i] >= largest_remaining && rsorted[i] >= largest_remaining\n",
    "            break\n",
    "        end\n",
    "\n",
    "        last_val = max(min(xsorted[i], rsorted[i]), last_val)\n",
    "        weight *= weight_scale\n",
    "    end\n",
    "\n",
    "    phi\n",
    "end"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Validation using a brute force Shapley method"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "brute_force_shapley (generic function with 1 method)"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "using Iterators\n",
    "\n",
    "# Shapley kernel weight s!(M-s-1)!/M! for a subset of size s drawn from the other\n",
    "# M-1 features. Evaluated as 1/(M*binomial(M-1, s)), which is the same quantity,\n",
    "# because factorial(M) overflows Int64 once M > 20.\n",
    "# Requires 0 <= s < M; throws ArgumentError otherwise.\n",
    "function shapley_weight(M, s)\n",
    "    0 <= s < M || throw(ArgumentError(\"need 0 <= s < M\"))\n",
    "    # float(M) keeps the product out of integer arithmetic so it cannot wrap around\n",
    "    1/(float(M)*binomial(M-1, s))\n",
    "end\n",
    "# Shapley value of feature `i` for the function `f` at the point `x`, obtained by\n",
    "# enumerating every subset of the other features. Features outside the subset keep\n",
    "# the value given in `missing_values`.\n",
    "function brute_force_shapley(f, x, missing_values, i)\n",
    "    M = length(x)\n",
    "    others = setdiff(1:M, [i])\n",
    "    scratch = zeros(length(missing_values))\n",
    "    phi = 0\n",
    "    for subset in subsets(others)\n",
    "        # start from the reference point, then switch on the features in the subset\n",
    "        scratch[:] = missing_values\n",
    "        scratch[subset] = x[subset]\n",
    "        without_i = f(scratch)\n",
    "        # same point with feature i switched on as well\n",
    "        scratch[i] = x[i]\n",
    "        with_i = f(scratch)\n",
    "        phi += shapley_weight(M, length(subset))*(with_i-without_i)\n",
    "    end\n",
    "    phi\n",
    "end"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Check shapley_max against the brute-force reference on 150 random instances:\n",
    "# 10 draws for each input length from 1 to 15.\n",
    "srand(2017)  # fixed seed (Julia 0.5 API) so a failing instance can be reproduced\n",
    "for M in 1:15, trial in 1:10\n",
    "    x = rand(M)\n",
    "    r = rand(M)\n",
    "    expected = [brute_force_shapley(maximum, x, r, k) for k in 1:M]\n",
    "    err = norm(shapley_max(x, r) .- expected)\n",
    "    @assert err < 1e-8\n",
    "end"
   ]
  }
 ],
 "metadata": {
  "anaconda-cloud": {},
  "kernelspec": {
   "display_name": "Julia 0.5.0",
   "language": "julia",
   "name": "julia-0.5"
  },
  "language_info": {
   "file_extension": ".jl",
   "mimetype": "application/julia",
   "name": "julia",
   "version": "0.5.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
