{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "142b0d02",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "15436909",
   "metadata": {},
   "outputs": [],
   "source": [
    "rets = {}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ff4b1ec2",
   "metadata": {},
   "outputs": [],
   "source": [
    "%ls"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3aed80db",
   "metadata": {},
   "outputs": [],
   "source": [
    "path = \"policy_dissection_ret\"\n",
    "with open(\"{}.pkl\".format(path), \"rb+\") as file:\n",
    "    data = pickle.load(file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aa193af4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# resort\n",
    "for seed, epi_data in data.items():\n",
    "    for obs_index, obs_ret in epi_data.items():\n",
    "        if obs_index == \"seed\":\n",
    "            continue\n",
    "        resort = sorted(obs_ret, key=lambda x: x[\"neuron\"][\"neuron_index\"], reverse=True)\n",
    "        resort = sorted(resort, key=lambda x: x[\"neuron\"][\"layer\"], reverse=True)\n",
    "        epi_data[obs_index] = resort\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fe9833df",
   "metadata": {},
   "outputs": [],
   "source": [
    "# average\n",
    "averaged_ret = {}\n",
    "for seed, epi_data in data.items():\n",
    "    for obs, obs_ret in epi_data.items():\n",
    "        if obs == \"seed\":\n",
    "            continue\n",
    "        if obs in averaged_ret:\n",
    "            for k,each_neuron in enumerate(obs_ret):\n",
    "                assert averaged_ret[obs][k][\"layer\"] == each_neuron[\"neuron\"][\"layer\"] and averaged_ret[obs][k][\"neuron_index\"] == each_neuron[\"neuron\"][\"neuron_index\"], \"{} {}\".format(averaged_ret[obs][k][\"neuron_index\"], each_neuron[\"neuron\"][\"neuron_index\"])\n",
    "                averaged_ret[obs][k][\"freq_diff\"] = (averaged_ret[obs][k][\"freq_diff\"]*averaged_ret[obs][k][\"seed_num\"] + each_neuron[\"error\"][\"freq_diff\"])/(averaged_ret[obs][k][\"seed_num\"]+1)\n",
    "                averaged_ret[obs][k][\"base_phase_diff\"] = (averaged_ret[obs][k][\"base_phase_diff\"]*averaged_ret[obs][k][\"seed_num\"] + each_neuron[\"error\"][\"base_phase_diff\"])/(averaged_ret[obs][k][\"seed_num\"]+1)\n",
    "                averaged_ret[obs][k][\"seed_num\"]+=1\n",
    "        else:\n",
    "            averaged_ret[obs] = []\n",
    "            for each_neuron in obs_ret:\n",
    "                averaged_ret[obs].append({\"layer\": each_neuron[\"neuron\"][\"layer\"],\n",
    "                                   \"neuron_index\": each_neuron[\"neuron\"][\"neuron_index\"],\n",
    "                                   \"freq_diff\": each_neuron[\"error\"][\"freq_diff\"],\n",
    "                                   \"base_phase_diff\": each_neuron[\"error\"][\"base_phase_diff\"],\n",
    "                                   \"seed_num\": 1})\n",
    "for obs_dim, avg_ret in averaged_ret.items():\n",
    "    averaged_ret[obs_dim] = sorted(avg_ret, key=lambda x:x[\"freq_diff\"],reverse=False)\n",
    "\n",
    "rets[path] = averaged_ret"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6d69d3d8",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(averaged_ret[1])"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
