{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "sys.path.append(os.path.realpath('../..'))\n",
    "import toy.ops as ops\n",
    "import toy.data as data\n",
    "import toy.net as net\n",
    "import toy.train as train\n",
    "import toy.ntk as ntk\n",
    "import toy.ground_truth as gt\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import torch\n",
    "import pandas as pd\n",
    "import re"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# rho values to evaluate; each maps to the subdirectory generate_dir + '/rho-{:.04f}'.\n",
    "rhos = [1.0000, 0.9599, 0.800, 0.6457, 0.5, 0.35, 0.1451, 0.0173, 0.0020]\n",
    "\n",
    "# Path is used verbatim (including the 'Stendent' spelling) to locate existing experiment outputs.\n",
    "generate_dir = '../../experiment/KD_training/Teacher_NN_Stendent_Linear_NN_Infinite_Data'\n",
    "batch_num = 512   # intended number of online batches drawn per rho\n",
    "batch_size = 512  # samples per batch (assigned to dataset.datanum)\n",
    "zero_net_path = generate_dir + '/zero/network/student'\n",
    "\n",
    "# Fall back to CPU so the notebook also runs on machines without a GPU.\n",
    "device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# For every rho, measure the angle between each sample's NTK direction and the student's\n",
    "# weight change, then histogram the angles that fall in a narrow band just below pi/2.\n",
    "n_bins = 100\n",
    "angle_range = (1.55, np.pi / 2)  # only near-orthogonal angles are histogrammed\n",
    "\n",
    "angle_data = []  # per rho: 1 - normalized cumulative histogram (fraction of in-range angles beyond each bin's right edge)\n",
    "p_alpha = []     # per rho: raw histogram counts\n",
    "zero_net = torch.load(zero_net_path, map_location=device)\n",
    "\n",
    "for rho in rhos:\n",
    "    current_dir = generate_dir + '/rho-{:.04f}'.format(rho)\n",
    "\n",
    "    # NOTE(review): online=True presumably makes each dataset[:] call draw a fresh batch of\n",
    "    # datanum samples -- confirm in toy.data.\n",
    "    dataset = torch.load(current_dir + '/dataset/train_dataset', map_location=device)\n",
    "    dataset.online = True\n",
    "    dataset.datanum = batch_size\n",
    "    dataset.device = device\n",
    "\n",
    "    init_net = torch.load(current_dir + '/network/init_net', map_location=device)\n",
    "    student_net = torch.load(current_dir + '/network/student', map_location=device)\n",
    "\n",
    "    # L2 norm of the parameter difference between this student and the zero reference net.\n",
    "    weight_change = torch.norm(student_net.vec() - zero_net.vec()).item()\n",
    "    print('weight_change:', weight_change)\n",
    "\n",
    "    data_0 = []\n",
    "    for _ in range(batch_num):\n",
    "        with torch.no_grad():\n",
    "            inputs = dataset[:][1]  # element 1 of the sampled batch is used as the network input\n",
    "            # assumes pointwise=True returns the per-sample diagonal Theta(x, x) -- TODO confirm\n",
    "            Theta = ntk.ReLU_NTK(\n",
    "                inputs,\n",
    "                weight_std=init_net.weight_std,\n",
    "                bias_std=init_net.bias_std,\n",
    "                hidden_layer_num=init_net.hidden_layer_num,\n",
    "                pointwise=True\n",
    "            )\n",
    "\n",
    "        # Kept outside no_grad, as in the original: the student forward takes init_net and may\n",
    "        # rely on autograd internally. Outputs are detached instead.\n",
    "        f = student_net(inputs, init_net).detach() - zero_net(inputs, init_net).detach()\n",
    "\n",
    "        # cos(angle) = |f| / (sqrt(Theta) * ||dw||); clamp so ratios slightly above 1 map to\n",
    "        # angle 0 instead of NaN (either way they fall outside angle_range).\n",
    "        cos_angle = (torch.abs(f) / torch.sqrt(Theta) / weight_change).clamp(max=1.0)\n",
    "        data_0.append(torch.acos(cos_angle).cpu().numpy())\n",
    "    data_0 = np.concatenate(data_0)\n",
    "\n",
    "    # The range is fixed, so `edge` is identical for every rho; the save cell uses the last one.\n",
    "    density, edge = np.histogram(data_0, bins=n_bins, range=angle_range)\n",
    "    p_beta = np.cumsum(density)\n",
    "    p_beta = 1 - p_beta / p_beta[-1]  # NaN if no angle falls inside angle_range\n",
    "\n",
    "    angle_data.append(p_beta)\n",
    "    p_alpha.append(density)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Persist the per-rho results. edge[1:] are the right bin edges, which align with p_beta\n",
    "# (p_beta[i] is the fraction of in-range angles beyond the right edge of bin i).\n",
    "output_dir = '../../experiment/02a'\n",
    "os.makedirs(output_dir, exist_ok=True)  # no-op if the directory already exists\n",
    "\n",
    "torch.save(rhos, os.path.join(output_dir, 'rhos'))\n",
    "torch.save(angle_data, os.path.join(output_dir, 'p_beta'))\n",
    "torch.save(p_alpha, os.path.join(output_dir, 'p_alpha'))\n",
    "torch.save(edge[1:], os.path.join(output_dir, 'edge'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}