{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aa193af4",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "\n",
    "def analyze(path):\n",
    "    with open(\"{}.pkl\".format(path), \"rb+\") as file:\n",
    "        data = pickle.load(file)\n",
    "\n",
    "    # resort\n",
    "    for seed, epi_data in data.items():\n",
    "        for obs_index, obs_ret in epi_data.items():\n",
    "            if obs_index == \"seed\":\n",
    "                continue\n",
    "            resort = sorted(obs_ret, key=lambda x: x[\"neuron\"][\"neuron_index\"], reverse=True)\n",
    "            resort = sorted(resort, key=lambda x: x[\"neuron\"][\"layer\"], reverse=True)\n",
    "            epi_data[obs_index] = resort\n",
    "\n",
    "            # average\n",
    "    averaged_ret = {}\n",
    "    for seed, epi_data in data.items():\n",
    "        for obs, obs_ret in epi_data.items():\n",
    "            if obs == \"seed\":\n",
    "                continue\n",
    "            if obs in averaged_ret:\n",
    "                for k,each_neuron in enumerate(obs_ret):\n",
    "                    assert averaged_ret[obs][k][\"layer\"] == each_neuron[\"neuron\"][\"layer\"] and averaged_ret[obs][k][\"neuron_index\"] == each_neuron[\"neuron\"][\"neuron_index\"], \"{} {}\".format(averaged_ret[obs][k][\"neuron_index\"], each_neuron[\"neuron\"][\"neuron_index\"])\n",
    "                    averaged_ret[obs][k][\"freq_diff\"] = (averaged_ret[obs][k][\"freq_diff\"]*averaged_ret[obs][k][\"seed_num\"] + each_neuron[\"error\"][\"freq_diff\"])/(averaged_ret[obs][k][\"seed_num\"]+1)\n",
    "                    averaged_ret[obs][k][\"correlation\"] = (averaged_ret[obs][k][\"correlation\"]*averaged_ret[obs][k][\"seed_num\"] + each_neuron[\"error\"][\"correlation\"])/(averaged_ret[obs][k][\"seed_num\"]+1)\n",
    "                    averaged_ret[obs][k][\"seed_num\"]+=1\n",
    "            else:\n",
    "                averaged_ret[obs] = []\n",
    "                for each_neuron in obs_ret:\n",
    "                    averaged_ret[obs].append({\"layer\": each_neuron[\"neuron\"][\"layer\"],\n",
    "                                       \"neuron_index\": each_neuron[\"neuron\"][\"neuron_index\"],\n",
    "                                       \"freq_diff\": each_neuron[\"error\"][\"freq_diff\"],\n",
    "                                       \"correlation\": each_neuron[\"error\"][\"correlation\"],\n",
    "                                       \"seed_num\": 1})\n",
    "    for obs_dim, avg_ret in averaged_ret.items():\n",
    "        averaged_ret[obs_dim] = sorted(avg_ret, key=lambda x:x[\"freq_diff\"],reverse=False)\n",
    "    return averaged_ret"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ca8514d3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# For anymal:\n",
    "# analyze appends \".pkl\", so this reads anymal_ret.pkl relative to the\n",
    "# working directory. The result is bound to the shared name averaged_ret,\n",
    "# which the MetaDrive cell later rebinds.\n",
    "averaged_ret = analyze(\"anymal_ret\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6d69d3d8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Must run while averaged_ret still holds the anymal results (before the\n",
    "# MetaDrive cell rebinds it).\n",
    "# Each averaged_ret[obs] list is sorted by ascending freq_diff in analyze,\n",
    "# so [obs][1] is the entry with the second-smallest averaged freq_diff.\n",
    "# NOTE(review): the MetaDrive cell reads index 0 instead -- confirm index 1\n",
    "# is intended here. That keys 5 and 1 are the angular / x-linear velocity\n",
    "# observations is taken from the labels below; verify against the data.\n",
    "print(\"angular velocity related neuron:\", averaged_ret[5][1]) \n",
    "print(\"x-linear velocity related neuron:\", averaged_ret[1][1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8364b4cc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# For MetaDrive:\n",
    "# Reads metadrive_ret.pkl (analyze appends \".pkl\") and rebinds averaged_ret,\n",
    "# replacing the anymal results; re-run the anymal load cell before\n",
    "# re-running the anymal print cell.\n",
    "averaged_ret = analyze(\"metadrive_ret\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c03e95be",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Each averaged_ret[obs] list is sorted by ascending freq_diff in analyze,\n",
    "# so [obs][0] is the entry with the smallest averaged freq_diff.\n",
    "# That keys 0 and 3 are the left-side-distance / speed observations is\n",
    "# taken from the labels below -- TODO confirm against the data.\n",
    "print(\"left side distance related neuron:\", averaged_ret[0][0]) \n",
    "print(\"speed related neuron:\", averaged_ret[3][0])"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
