{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append(\"..\")\n",
    "from train_by_reconnect.LaPerm import LaPermTrainLoop\n",
    "from train_by_reconnect.weight_utils import agnosticize\n",
    "from train_by_reconnect.viz_utils import Profiler"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "print(tf.__version__)\n",
    "from tqdm.notebook import trange as nested_progress_bar\n",
    "import matplotlib.pyplot as plt\n",
    "from tensorflow.keras.preprocessing.image import ImageDataGenerator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tensorflow.keras.models import Sequential\n",
    "from tensorflow.keras.layers import Dense, Flatten\n",
    "from tensorflow.keras.regularizers import l2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load MNIST, cast the images to float32, divide by 255.0 and append a\n",
    "# trailing axis so every image matches the (28, 28, 1) model input shape.\n",
    "(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()\n",
    "x_train = (x_train.astype('float32') / 255.0)[..., np.newaxis]\n",
    "x_test = (x_test.astype('float32') / 255.0)[..., np.newaxis]\n",
    "\n",
    "# Generator built with default settings, i.e. no data augmentation.\n",
    "datagen = ImageDataGenerator()\n",
    "datagen.fit(x_train)\n",
    "\n",
    "# Training hyper-parameters shared by the experiments below.\n",
    "batch_size = 128\n",
    "learning_rate = 0.001  # initial learning rate (see lr_scheduler)\n",
    "tsize = 30000  # number of samples used when measuring the train accuracy\n",
    "vali_freq = 250  # run validation once every vali_freq batches\n",
    "\n",
    "def lr_scheduler(epoch):\n",
    "    \"\"\"Return the learning rate for `epoch`: `learning_rate` times 0.95 ** epoch.\"\"\"\n",
    "    decay = 0.95 ** epoch\n",
    "    return learning_rate * decay\n",
    "\n",
    "def k_scheduler(epoch):\n",
    "    return 250"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false,
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "outputs": [],
   "source": [
    "# Shared value handed to agnosticize for the weights\n",
    "val = 0.08\n",
    "\n",
    "# Pruning setting passed to agnosticize as prune_ratio\n",
    "# (one tuple for the single Dense layer of this model)\n",
    "rate = [(4, 7)]\n",
    "\n",
    "# Section 5.5 F_1 model: flatten the (28, 28, 1) input into one bias-free softmax layer\n",
    "F1 = Sequential([\n",
    "    Flatten(input_shape=(28, 28, 1)),\n",
    "    Dense(10, activation='softmax', use_bias=False),\n",
    "])\n",
    "\n",
    "# Make F1 weight agnostic, then profile it to confirm\n",
    "agnosticize(F1, val=val, prune_ratio=rate)\n",
    "Profiler(F1, skip_1d=False)\n",
    "\n",
    "# Train with LaPerm for 10 epochs\n",
    "epochs = 10\n",
    "loop = LaPermTrainLoop(\n",
    "    model=F1,\n",
    "    loss='sparse_categorical_crossentropy',\n",
    "    inner_optimizer=tf.keras.optimizers.Adam(),\n",
    "    k_schedule=k_scheduler,\n",
    "    lr_schedule=lr_scheduler,\n",
    "    skip_bias=False,\n",
    ")\n",
    "loop.fit(\n",
    "    x_train, y_train,\n",
    "    batch_size=batch_size,\n",
    "    epochs=epochs,\n",
    "    datagen=datagen,\n",
    "    validation_data=(x_test, y_test),\n",
    "    validation_freq=vali_freq,\n",
    "    tsize=tsize,\n",
    ")\n",
    "\n",
    "# Restore the weights that reached the best validation accuracy\n",
    "F1.set_weights(loop.best_weights)\n",
    "# Cross-check the accuracy with Keras's own evaluate\n",
    "F1.compile(loss='sparse_categorical_crossentropy', metrics=['acc'])\n",
    "F1.evaluate(x_test, y_test)\n",
    "\n",
    "# Confirm that the model is still weight agnostic after training\n",
    "Profiler(F1, skip_1d=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Same procedure as the F1 cell above (see its comments), applied to a deeper model.\n",
    "val = 0.03\n",
    "rate = [(1, 15), (1, 7), (1, 3)]\n",
    "\n",
    "# Section 5.5 F_2 model: three bias-free Dense layers, the two hidden ones L2-regularized\n",
    "F2 = Sequential([\n",
    "    Flatten(input_shape=(28, 28, 1)),\n",
    "    Dense(128, activation='relu', use_bias=False, kernel_regularizer=l2(1e-4)),\n",
    "    Dense(64, activation='relu', use_bias=False, kernel_regularizer=l2(1e-4)),\n",
    "    Dense(10, activation='softmax', use_bias=False),\n",
    "])\n",
    "\n",
    "# Make F2 weight agnostic and profile it before training\n",
    "agnosticize(F2, val, rate)\n",
    "Profiler(F2)\n",
    "\n",
    "# Train with LaPerm for 25 epochs\n",
    "epochs = 25\n",
    "loop2 = LaPermTrainLoop(\n",
    "    model=F2,\n",
    "    loss='sparse_categorical_crossentropy',\n",
    "    inner_optimizer=tf.keras.optimizers.Adam(),\n",
    "    k_schedule=k_scheduler,\n",
    "    lr_schedule=lr_scheduler,\n",
    "    skip_bias=False,\n",
    ")\n",
    "loop2.fit(\n",
    "    x_train, y_train,\n",
    "    batch_size=batch_size,\n",
    "    epochs=epochs,\n",
    "    datagen=datagen,\n",
    "    validation_data=(x_test, y_test),\n",
    "    validation_freq=vali_freq,\n",
    "    tsize=tsize,\n",
    ")\n",
    "\n",
    "# Restore the best-validation weights and cross-check with Keras's evaluate\n",
    "F2.set_weights(loop2.best_weights)\n",
    "F2.compile(loss='sparse_categorical_crossentropy', metrics=['acc'])\n",
    "F2.evaluate(x_test, y_test)\n",
    "\n",
    "# Profile the trained model\n",
    "Profiler(F2, skip_1d=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Visualize the train and validation accuracies recorded by `loop` (the F1 run).\n",
    "# matplotlib is already imported in the import cell near the top of the\n",
    "# notebook, so it is not imported again here; numpy is not needed by this cell.\n",
    "# NOTE(review): the x-axis is the index into the recorded history; confirm that\n",
    "# one entry is logged per epoch (validation_freq above is expressed in batches).\n",
    "fig, ax = plt.subplots(figsize=(8, 6))\n",
    "ax.plot(loop._history['val accuracy'], label='Validation Accuracy')\n",
    "ax.plot(loop._history['accuracy'], label='Train Accuracy')\n",
    "ax.grid(linestyle='--')\n",
    "ax.set_xlabel('Epochs', size=15)\n",
    "ax.set_ylabel('Accuracy', size=15)\n",
    "# A title lets the figure stand alone; only F1's history is drawn here.\n",
    "ax.set_title('F1: LaPerm train vs. validation accuracy', size=15)\n",
    "ax.legend(prop={'size': 15})\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.7"
  },
  "varInspector": {
   "cols": {
    "lenName": 16,
    "lenType": 16,
    "lenVar": 40
   },
   "kernels_config": {
    "python": {
     "delete_cmd_postfix": "",
     "delete_cmd_prefix": "del ",
     "library": "var_list.py",
     "varRefreshCmd": "print(var_dic_list())"
    },
    "r": {
     "delete_cmd_postfix": ") ",
     "delete_cmd_prefix": "rm(",
     "library": "var_list.r",
     "varRefreshCmd": "cat(var_dic_list()) "
    }
   },
   "types_to_exclude": [
    "module",
    "function",
    "builtin_function_or_method",
    "instance",
    "_Feature"
   ],
   "window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
