{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append(\"..\")\n",
    "from train_by_reconnect.LaPerm import LaPermTrainLoop\n",
    "from train_by_reconnect.weight_utils import random_prune\n",
    "from train_by_reconnect.viz_utils import Profiler"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "import numpy as np\n",
    "from tensorflow.keras.models import Sequential\n",
    "from tensorflow.keras.layers import Conv2D, BatchNormalization, Dropout, Dense, Flatten\n",
    "from tensorflow.keras.preprocessing.image import ImageDataGenerator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Training hyper-parameters\n",
    "batch_size, epochs = 50, 45\n",
    "learning_rate = 1e-3  # initial learning rate\n",
    "tsize = 30000  # number of training samples used to measure train accuracy\n",
    "vali_freq = 1200  # run validation once every vali_freq batches\n",
    "\n",
    "\n",
    "def _to_unit_float_images(pixels):\n",
    "    \"\"\"Cast to float32, divide by 255.0 and append a trailing axis.\"\"\"\n",
    "    return np.expand_dims(pixels.astype('float32') / 255.0, axis=-1)\n",
    "\n",
    "\n",
    "# MNIST as shipped with tf.keras; images are rescaled, labels stay untouched.\n",
    "(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()\n",
    "x_train = _to_unit_float_images(x_train)\n",
    "x_test = _to_unit_float_images(x_test)\n",
    "\n",
    "\n",
    "def k_scheduler(epoch):\n",
    "    \"\"\"Return the value handed to LaPermTrainLoop through `k_schedule`.\n",
    "\n",
    "    `epoch` is accepted but ignored: the same constant is returned every time.\n",
    "    NOTE(review): k is presumably LaPerm's sync period in batches -- confirm\n",
    "    against train_by_reconnect.LaPerm.\n",
    "    \"\"\"\n",
    "    constant_k = 1200\n",
    "    return constant_k\n",
    "\n",
    "\n",
    "def lr_scheduler(epoch):\n",
    "    \"\"\"Return the learning rate for `epoch` under exponential decay.\n",
    "\n",
    "    Starts at the notebook-level `learning_rate` (the initial learning rate\n",
    "    declared with the other training parameters) and multiplies it by 0.95\n",
    "    once per epoch, i.e. learning_rate * 0.95 ** epoch.\n",
    "    \"\"\"\n",
    "    decay_per_epoch = 0.95\n",
    "    # Read the shared constant instead of repeating the literal 1e-3, so the\n",
    "    # declared initial learning rate is the single source of truth.\n",
    "    return learning_rate * decay_per_epoch ** epoch\n",
    "\n",
    "\n",
    "# data augmentation: random rotation, zoom and horizontal/vertical shifts\n",
    "augmentation_options = dict(\n",
    "    rotation_range=10,\n",
    "    zoom_range=0.10,\n",
    "    width_shift_range=0.1,\n",
    "    height_shift_range=0.1,\n",
    ")\n",
    "datagen = ImageDataGenerator(**augmentation_options)\n",
    "# NOTE(review): per the Keras docs, fit() is only required for featurewise\n",
    "# statistics / ZCA whitening, none of which are enabled here; kept as-is.\n",
    "datagen.fit(x_train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "initializer = 'he_uniform'\n",
    "\n",
    "\n",
    "def _relu_conv(filters, kernel_size, **extra):\n",
    "    \"\"\"Build a ReLU Conv2D whose kernel and bias both use `initializer`.\"\"\"\n",
    "    return Conv2D(filters, kernel_size=kernel_size, activation='relu',\n",
    "                  kernel_initializer=initializer,\n",
    "                  bias_initializer=initializer,\n",
    "                  **extra)\n",
    "\n",
    "\n",
    "model = Sequential([\n",
    "    # 32-filter stage; the last conv uses stride 2 with 'same' padding\n",
    "    _relu_conv(32, 3, input_shape=(28, 28, 1)),\n",
    "    _relu_conv(32, 3),\n",
    "    _relu_conv(32, 5, strides=2, padding='same'),\n",
    "    Dropout(0.4),\n",
    "    # 64-filter stage, same layout as above\n",
    "    _relu_conv(64, 3),\n",
    "    _relu_conv(64, 3),\n",
    "    _relu_conv(64, 5, strides=2, padding='same'),\n",
    "    Dropout(0.4),\n",
    "    # 128-filter conv, then flatten into the 10-way softmax classifier\n",
    "    _relu_conv(128, 4),\n",
    "    Flatten(),\n",
    "    Dropout(0.4),\n",
    "    Dense(10, activation='softmax',\n",
    "          kernel_initializer=initializer,\n",
    "          bias_initializer=initializer),\n",
    "])\n",
    "\n",
    "# random_prune(model, prune_rate=0.7) # uncomment for random pruning"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# Wrap the model in the LaPerm training loop; Adam (default settings) is the\n",
    "# inner optimizer, and both schedules are the functions defined earlier.\n",
    "loop = LaPermTrainLoop(\n",
    "    model=model,\n",
    "    loss='sparse_categorical_crossentropy',\n",
    "    inner_optimizer=tf.keras.optimizers.Adam(),\n",
    "    k_schedule=k_scheduler,\n",
    "    lr_schedule=lr_scheduler,\n",
    "    skip_bias=False,\n",
    ")\n",
    "\n",
    "# Train on the augmented stream and validate on the MNIST test split.\n",
    "fit_options = dict(\n",
    "    batch_size=batch_size,\n",
    "    epochs=epochs,\n",
    "    datagen=datagen,\n",
    "    validation_data=(x_test, y_test),\n",
    "    validation_freq=vali_freq,\n",
    "    tsize=tsize,\n",
    ")\n",
    "loop.fit(x_train, y_train, **fit_options)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Last expression of the cell, so whatever Profiler(model) returns is what\n",
    "# the notebook displays for this cell.\n",
    "# NOTE(review): Profiler is imported from train_by_reconnect.viz_utils;\n",
    "# presumably it visualizes the model's weights -- confirm against that module.\n",
    "Profiler(model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Visualize train and validation accuracies\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(8, 6))\n",
    "# NOTE(review): `_history` is a private attribute of the loop; switch to a\n",
    "# public accessor if LaPermTrainLoop exposes one.\n",
    "for history_key, curve_label in (('val accuracy', 'Validation Accuracy'),\n",
    "                                 ('accuracy', 'Train Accuracy')):\n",
    "    ax.plot(loop._history[history_key], label=curve_label)\n",
    "ax.grid(linestyle='--')\n",
    "ax.set_xlabel('Epochs', size=15)\n",
    "ax.set_ylabel('Accuracy', size=15)\n",
    "ax.legend(prop={'size': 15})\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.7"
  },
  "varInspector": {
   "cols": {
    "lenName": 16,
    "lenType": 16,
    "lenVar": 40
   },
   "kernels_config": {
    "python": {
     "delete_cmd_postfix": "",
     "delete_cmd_prefix": "del ",
     "library": "var_list.py",
     "varRefreshCmd": "print(var_dic_list())"
    },
    "r": {
     "delete_cmd_postfix": ") ",
     "delete_cmd_prefix": "rm(",
     "library": "var_list.r",
     "varRefreshCmd": "cat(var_dic_list()) "
    }
   },
   "types_to_exclude": [
    "module",
    "function",
    "builtin_function_or_method",
    "instance",
    "_Feature"
   ],
   "window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
