{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "colab": {
     "autoexec": {
      "startup": false,
      "wait_interval": 0
     }
    },
    "colab_type": "code",
    "id": "BCDocHHGtK8u"
   },
   "outputs": [],
   "source": [
    "import collections\n",
    "import numpy as np\n",
    "import heapq"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "colab": {
     "autoexec": {
      "startup": false,
      "wait_interval": 0
     }
    },
    "colab_type": "code",
    "id": "s5wrmyyExo29"
   },
   "outputs": [],
   "source": [
    "def sample_gumbel(mu):\n",
    "    \"\"\"Sample a Gumbel(mu).\"\"\"    \n",
    "    return -np.log(np.random.exponential()) + mu\n",
    "\n",
    "\n",
    "def sample_truncated_gumbel(mu, b):\n",
    "    \"\"\"Sample a Gumbel(mu) truncated to be less than b.\"\"\"    \n",
    "    return -np.log(np.random.exponential() + np.exp(-b + mu)) + mu\n",
    "\n",
    "  \n",
    "def sample_gumbel_argmax(logits):\n",
    "    \"\"\"Sample from a softmax distribution over logits.\n",
    "\n",
    "    TODO: check this is correct.\n",
    "\n",
    "    Args:\n",
    "    logits: A flat numpy array of logits.\n",
    "\n",
    "    Returns:\n",
    "    A sample from softmax(logits).\n",
    "    \"\"\"\n",
    "    return np.argmax(-np.log(np.random.exponential(size=logits.shape)) + logits)\n",
    "\n",
    "\n",
    "def logsumexp(logits):\n",
    "    c = np.max(logits)\n",
    "    return np.log(np.sum(np.exp(logits - c))) + c\n",
    "\n",
    "\n",
    "def log_softmax(logits, axis=1):\n",
    "    \"\"\"Normalize logits per row so that they are logprobs.    \n",
    "    \"\"\"\n",
    "    maxes = np.max(logits, axis=axis, keepdims=True)\n",
    "    offset_logits = logits - maxes\n",
    "\n",
    "    log_zs = np.log(np.sum(np.exp(offset_logits), axis=axis, keepdims=True))\n",
    "    return offset_logits - log_zs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Policy:\n",
    "    def __init__(self, num_actions):\n",
    "        self.W = np.random.randn(num_actions)\n",
    "        self.b = np.random.randn(num_actions)\n",
    "        self.num_actions = num_actions\n",
    "\n",
    "    def __call__(self,state):\n",
    "        one_hot_state = np.zeros(self.num_actions)\n",
    "        one_hot_state[state]=1\n",
    "\n",
    "        out = self.W*one_hot_state + self.b\n",
    "\n",
    "        return log_softmax(out, axis=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "colab": {
     "autoexec": {
      "startup": false,
      "wait_interval": 0
     }
    },
    "colab_type": "code",
    "id": "diD2b31djUAN"
   },
   "outputs": [],
   "source": [
    "# Make a node without a state, and also don't allow `next_actions` to be None.\n",
    "# Just put in all possible next actions when the node is created.\n",
    "Node = collections.namedtuple('Node', \n",
    "                              [\n",
    "                                  'prefix',\n",
    "                                  'logprob_so_far',\n",
    "                                  'max_gumbel', \n",
    "                                  'next_actions',\n",
    "                              ])\n",
    "\n",
    "# Namedtuple for storing results\n",
    "Trajectory = collections.namedtuple('Trajectory', ['actions', 'gumbel'])\n",
    "\n",
    "\n",
    "def sample_trajectory_gumbels(init_state, max_length, num_actions, policy=None):\n",
    "    \"\"\"Samples an independent Gumbel(logprob) for each trajectory, largest first.\n",
    "\n",
    "    The search keeps a priority queue of nodes keyed on the max Gumbel of the\n",
    "    trajectories each node stands for and always expands the node with the\n",
    "    largest one. A child's max Gumbel never exceeds its parent's, so complete\n",
    "    trajectories are emitted in decreasing order of their Gumbel value.\n",
    "\n",
    "    Args:\n",
    "    init_state: State passed to the policy for the first action. Later steps\n",
    "      pass the most recently chosen action as the state.\n",
    "    max_length: Maximum length of a trajectory to allow; every returned\n",
    "      trajectory has exactly this many actions.\n",
    "    num_actions: Number of actions available at every step.\n",
    "    policy: Optional callable mapping a state to an array of `num_actions`\n",
    "      action log-probabilities. Defaults to a new `Policy(num_actions)`.\n",
    "\n",
    "    Returns:\n",
    "    A list of `Trajectory` tuples, one per action sequence of length\n",
    "    `max_length`, ordered from largest to smallest Gumbel.\n",
    "    \"\"\"\n",
    "    if policy is None:\n",
    "        policy = Policy(num_actions)\n",
    "    # Nodes never mutate `next_actions`, so one shared list is safe.\n",
    "    all_actions = list(range(num_actions))\n",
    "\n",
    "    # heapq is a min-heap, so entries are keyed on the negated max Gumbel to pop\n",
    "    # the largest first. The running counter breaks ties so that the Node tuples\n",
    "    # themselves (lists, ranges) are never compared.\n",
    "    queue = []\n",
    "    num_pushed = 0\n",
    "\n",
    "    def push(node):\n",
    "        nonlocal num_pushed\n",
    "        heapq.heappush(queue, (-node.max_gumbel, num_pushed, node))\n",
    "        num_pushed += 1\n",
    "\n",
    "    # Start with a node for all trajectories.\n",
    "    push(Node(prefix=[],\n",
    "              logprob_so_far=0,\n",
    "              max_gumbel=sample_gumbel(0),\n",
    "              next_actions=all_actions))\n",
    "    final_trajectories = []\n",
    "\n",
    "    while queue:\n",
    "        _, _, parent = heapq.heappop(queue)\n",
    "\n",
    "        if len(parent.prefix) == max_length:\n",
    "            final_trajectories.append(Trajectory(actions=parent.prefix,\n",
    "                                                 gumbel=parent.max_gumbel))\n",
    "            continue\n",
    "\n",
    "        # The policy is conditioned on the last action taken, or on `init_state`\n",
    "        # when no action has been taken yet.\n",
    "        current_state = parent.prefix[-1] if parent.prefix else init_state\n",
    "        action_logprobs = np.asarray(policy(current_state))\n",
    "\n",
    "        # Choose one action from amongst the set of candidates to inherit the max\n",
    "        # gumbel. Call this the \"special\" action.\n",
    "        next_action_logprobs = action_logprobs[parent.next_actions]\n",
    "        special_action_index = sample_gumbel_argmax(next_action_logprobs)\n",
    "        special_action = parent.next_actions[special_action_index]\n",
    "        special_action_logprob = action_logprobs[special_action]\n",
    "\n",
    "        push(Node(prefix=parent.prefix + [special_action],\n",
    "                  logprob_so_far=parent.logprob_so_far + special_action_logprob,\n",
    "                  max_gumbel=parent.max_gumbel,\n",
    "                  next_actions=all_actions))  # All next actions are possible.\n",
    "\n",
    "        # Sample the max gumbel for the non-chosen actions and create an \"other\n",
    "        # children\" node if there are any alternatives left.\n",
    "        other_actions = [i for i in parent.next_actions if i != special_action]\n",
    "\n",
    "        assert len(other_actions) == len(parent.next_actions) - 1\n",
    "\n",
    "        if other_actions:\n",
    "            # The max over the remaining actions is a Gumbel located at the log of\n",
    "            # their total probability mass, truncated at the parent's max.\n",
    "            other_max_location = logsumexp(action_logprobs[other_actions])\n",
    "            other_max_gumbel = sample_truncated_gumbel(\n",
    "                parent.logprob_so_far + other_max_location,\n",
    "                parent.max_gumbel)\n",
    "            push(Node(prefix=parent.prefix,\n",
    "                      logprob_so_far=parent.logprob_so_far,\n",
    "                      max_gumbel=other_max_gumbel,\n",
    "                      next_actions=other_actions))\n",
    "    return final_trajectories"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Expected 64 vs actual 64\n",
      "Trajectory(actions=[0, 0, 0], gumbel=-4.5858048132116345)\n",
      "Trajectory(actions=[0, 0, 1], gumbel=-4.425509736272827)\n",
      "Trajectory(actions=[0, 0, 2], gumbel=-6.694908748987112)\n",
      "Trajectory(actions=[0, 0, 3], gumbel=-3.859501004192877)\n",
      "Trajectory(actions=[0, 1, 0], gumbel=-3.5402832531686412)\n",
      "Trajectory(actions=[0, 1, 1], gumbel=0.9782085030670724)\n",
      "Trajectory(actions=[0, 1, 2], gumbel=-4.68739911668616)\n",
      "Trajectory(actions=[0, 1, 3], gumbel=-5.464147651437554)\n",
      "Trajectory(actions=[0, 2, 0], gumbel=-6.354922982709571)\n",
      "Trajectory(actions=[0, 2, 1], gumbel=-6.302920002128424)\n",
      "Trajectory(actions=[0, 2, 2], gumbel=-6.42671518203762)\n",
      "Trajectory(actions=[0, 2, 3], gumbel=-3.8761070531691013)\n",
      "Trajectory(actions=[0, 3, 0], gumbel=-4.462718952794387)\n",
      "Trajectory(actions=[0, 3, 1], gumbel=-1.8819105907406584)\n",
      "Trajectory(actions=[0, 3, 2], gumbel=-5.93267905198373)\n",
      "Trajectory(actions=[0, 3, 3], gumbel=-4.558111131141923)\n",
      "Trajectory(actions=[1, 0, 0], gumbel=-3.41378553766058)\n",
      "Trajectory(actions=[1, 0, 1], gumbel=-3.2217960447179905)\n",
      "Trajectory(actions=[1, 0, 2], gumbel=-7.422652990181401)\n",
      "Trajectory(actions=[1, 0, 3], gumbel=-4.354064901151166)\n",
      "Trajectory(actions=[1, 1, 0], gumbel=-3.1272705768550937)\n",
      "Trajectory(actions=[1, 1, 1], gumbel=-0.40413438322079076)\n",
      "Trajectory(actions=[1, 1, 2], gumbel=-6.175006606323954)\n",
      "Trajectory(actions=[1, 1, 3], gumbel=-1.8888270028447853)\n",
      "Trajectory(actions=[1, 2, 0], gumbel=-5.774677859593021)\n",
      "Trajectory(actions=[1, 2, 1], gumbel=-3.880027503829333)\n",
      "Trajectory(actions=[1, 2, 2], gumbel=-9.441206832769618)\n",
      "Trajectory(actions=[1, 2, 3], gumbel=-6.140831148095233)\n",
      "Trajectory(actions=[1, 3, 0], gumbel=-4.020099102940989)\n",
      "Trajectory(actions=[1, 3, 1], gumbel=-1.8767149000974355)\n",
      "Trajectory(actions=[1, 3, 2], gumbel=-6.85397369345645)\n",
      "Trajectory(actions=[1, 3, 3], gumbel=-4.22883780760061)\n",
      "Trajectory(actions=[2, 0, 0], gumbel=-5.597355396485076)\n",
      "Trajectory(actions=[2, 0, 1], gumbel=-5.687203486202886)\n",
      "Trajectory(actions=[2, 0, 2], gumbel=-9.835885374051086)\n",
      "Trajectory(actions=[2, 0, 3], gumbel=-5.699381705306588)\n",
      "Trajectory(actions=[2, 1, 0], gumbel=-8.314147663935623)\n",
      "Trajectory(actions=[2, 1, 1], gumbel=-3.7483479062191787)\n",
      "Trajectory(actions=[2, 1, 2], gumbel=-6.918697984612872)\n",
      "Trajectory(actions=[2, 1, 3], gumbel=-6.898074888016014)\n",
      "Trajectory(actions=[2, 2, 0], gumbel=-6.899810954479554)\n",
      "Trajectory(actions=[2, 2, 1], gumbel=-7.836327515820941)\n",
      "Trajectory(actions=[2, 2, 2], gumbel=-11.893719105521626)\n",
      "Trajectory(actions=[2, 2, 3], gumbel=-7.889633842684125)\n",
      "Trajectory(actions=[2, 3, 0], gumbel=-3.4128390853194013)\n",
      "Trajectory(actions=[2, 3, 1], gumbel=-5.8986774160510596)\n",
      "Trajectory(actions=[2, 3, 2], gumbel=-8.26790651403305)\n",
      "Trajectory(actions=[2, 3, 3], gumbel=-5.723572172797275)\n",
      "Trajectory(actions=[3, 0, 0], gumbel=-2.8632706596666737)\n",
      "Trajectory(actions=[3, 0, 1], gumbel=-0.21543469768806656)\n",
      "Trajectory(actions=[3, 0, 2], gumbel=-5.874624649238971)\n",
      "Trajectory(actions=[3, 0, 3], gumbel=-1.272666156486469)\n",
      "Trajectory(actions=[3, 1, 0], gumbel=-3.981113153887286)\n",
      "Trajectory(actions=[3, 1, 1], gumbel=-2.6559750351322204)\n",
      "Trajectory(actions=[3, 1, 2], gumbel=-7.3159166594710925)\n",
      "Trajectory(actions=[3, 1, 3], gumbel=-3.5888991510605575)\n",
      "Trajectory(actions=[3, 2, 0], gumbel=-3.7494322571782925)\n",
      "Trajectory(actions=[3, 2, 1], gumbel=-4.56030688465374)\n",
      "Trajectory(actions=[3, 2, 2], gumbel=-6.982575874795728)\n",
      "Trajectory(actions=[3, 2, 3], gumbel=-4.253518975573278)\n",
      "Trajectory(actions=[3, 3, 0], gumbel=-3.4915003338086676)\n",
      "Trajectory(actions=[3, 3, 1], gumbel=-3.2834769529550023)\n",
      "Trajectory(actions=[3, 3, 2], gumbel=-3.735316944963312)\n",
      "Trajectory(actions=[3, 3, 3], gumbel=-2.850237256634672)\n"
     ]
    }
   ],
   "source": [
    "num_actions = 4\n",
    "trajectory_length = 3\n",
    "init_state = 0\n",
    "\n",
    "# Sample a Gumbel for every action sequence of length `trajectory_length`.\n",
    "trajectories = sample_trajectory_gumbels(init_state, trajectory_length, num_actions)\n",
    "\n",
    "# One trajectory is expected per possible action sequence.\n",
    "expected_count = num_actions ** trajectory_length\n",
    "print(\"Expected {} vs actual {}\".format(expected_count, len(trajectories)))\n",
    "for trajectory in trajectories:\n",
    "    print(trajectory)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "colab": {
   "default_view": {},
   "last_runtime": {
    "build_target": "//learning/dist_belief/client/python:brain_notebook",
    "kind": "private"
   },
   "name": "Autoregressive A* Sampling.ipynb",
   "provenance": [],
   "version": "0.3.2",
   "views": {}
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}